diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
new file mode 100644
index 000000000000..0374846d7167
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
+ *
+ * Each array has two parts: [offsets] [values]
+ *
+ * In the `offsets` region, we store one 4-byte integer per element, which is the start offset
+ * of that element in the `values` region. The length of an element is obtained by subtracting
+ * its offset from the next offset. Note that an offset can be negative, which means the
+ * element is null.
+ *
+ * In the `values` region, we store the content of the elements. Since lengths can be computed
+ * from the offsets, elements may be variable-length.
+ *
+ * Note that when we write out this array, we should write `numElements` into the first
+ * 4 bytes, followed by the content. When we read an array in, we should read the first
+ * 4 bytes as `numElements` and take the rest as content.
+ *
+ * Instances of `UnsafeArrayData` act as pointers to array data stored in this format.
+ */
+// todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
+public class UnsafeArrayData extends ArrayData {
+
+  private Object baseObject;
+  private long baseOffset;
+
+  // The number of elements in this array
+  private int numElements;
+
+  // The size of this array's backing data, in bytes
+  private int sizeInBytes;
+
+  private int getElementOffset(int ordinal) {
+    return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L);
+  }
+
+  private int getElementSize(int offset, int ordinal) {
+    if (ordinal == numElements - 1) {
+      return sizeInBytes - offset;
+    } else {
+      return Math.abs(getElementOffset(ordinal + 1)) - offset;
+    }
+  }
+
+  private void assertIndexIsValid(int ordinal) {
+    assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
+    assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
+  }
+
+  /**
+   * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until
+   * `pointTo()` has been called, since the value returned by this constructor is equivalent
+   * to a null pointer.
+   */
+  public UnsafeArrayData() { }
+
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
+
+  @Override
+  public int numElements() { return numElements; }
+
+  /**
+   * Update this UnsafeArrayData to point to different backing data.
+   *
+   * @param baseObject the base object
+   * @param baseOffset the offset within the base object
+   * @param numElements the number of elements in this array
+   * @param sizeInBytes the size of this array's backing data, in bytes
+   */
+  public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+    assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+    this.numElements = numElements;
+    this.baseObject = baseObject;
+    this.baseOffset = baseOffset;
+    this.sizeInBytes = sizeInBytes;
+  }
+
+  @Override
+  public boolean isNullAt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    return getElementOffset(ordinal) < 0;
+  }
+
+  @Override
+  public Object get(int ordinal, DataType dataType) {
+    if (isNullAt(ordinal) || dataType instanceof NullType) {
+      return null;
+    } else if (dataType instanceof BooleanType) {
+      return getBoolean(ordinal);
+    } else if (dataType instanceof ByteType) {
+      return getByte(ordinal);
+    } else if (dataType instanceof ShortType) {
+      return getShort(ordinal);
+    } else if (dataType instanceof IntegerType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof LongType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof FloatType) {
+      return getFloat(ordinal);
+    } else if (dataType instanceof DoubleType) {
+      return getDouble(ordinal);
+    } else if (dataType instanceof DecimalType) {
+      DecimalType dt = (DecimalType) dataType;
+      return getDecimal(ordinal, dt.precision(), dt.scale());
+    } else if (dataType instanceof DateType) {
+      return getInt(ordinal);
+    } else if (dataType instanceof TimestampType) {
+      return getLong(ordinal);
+    } else if (dataType instanceof BinaryType) {
+      return getBinary(ordinal);
+    } else if (dataType instanceof StringType) {
+      return getUTF8String(ordinal);
+    } else if (dataType instanceof CalendarIntervalType) {
+      return getInterval(ordinal);
+    } else if (dataType instanceof StructType) {
+      return getStruct(ordinal, ((StructType) dataType).size());
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
+    }
+  }
+
+  @Override
+  public boolean getBoolean(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return false;
+    return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public byte getByte(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public short getShort(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public int getInt(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public long getLong(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public float getFloat(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public double getDouble(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return 0;
+    return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset);
+  }
+
+  @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+
+    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+      return Decimal.apply(value, precision, scale);
+    } else {
+      final byte[] bytes = getBinary(ordinal);
+      final BigInteger bigInteger = new BigInteger(bytes);
+      final BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+      return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
+    }
+  }
+
+  @Override
+  public UTF8String getUTF8String(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int size = getElementSize(offset, ordinal);
+    return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
+  }
+
+  @Override
+  public byte[] getBinary(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int size = getElementSize(offset, ordinal);
+    final byte[] bytes = new byte[size];
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset + offset,
+      bytes,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      size);
+    return bytes;
+  }
+
+  @Override
+  public CalendarInterval getInterval(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+    final long microseconds =
+      PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
+    return new CalendarInterval(months, microseconds);
+  }
+
+  @Override
+  public InternalRow getStruct(int ordinal, int numFields) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int size = getElementSize(offset, ordinal);
+    final UnsafeRow row = new UnsafeRow();
+    row.pointTo(baseObject, baseOffset + offset, numFields, size);
+    return row;
+  }
+
+  @Override
+  public ArrayData getArray(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int size = getElementSize(offset, ordinal);
+    return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    assertIndexIsValid(ordinal);
+    final int offset = getElementOffset(ordinal);
+    if (offset < 0) return null;
+    final int size = getElementSize(offset, ordinal);
+    return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+  }
+
+  @Override
+  public int hashCode() {
+    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof UnsafeArrayData) {
+      UnsafeArrayData o = (UnsafeArrayData) other;
+      return (sizeInBytes == o.sizeInBytes) &&
+        ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
+          sizeInBytes);
+    }
+    return false;
+  }
+
+  public void writeToMemory(Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset,
+      target,
+      targetOffset,
+      sizeInBytes
+    );
+  }
+
+  @Override
+  public UnsafeArrayData copy() {
+    UnsafeArrayData arrayCopy = new UnsafeArrayData();
+    final byte[] arrayDataCopy = new byte[sizeInBytes];
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset,
+      arrayDataCopy,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      sizeInBytes
+    );
+    arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+    return arrayCopy;
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
new file mode 100644
index 000000000000..46216054ab38
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.types.ArrayData;
+import org.apache.spark.sql.types.MapData;
+
+/**
+ * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
+ *
+ * Currently we just use two UnsafeArrayData (one for keys, one for values) to represent an
+ * UnsafeMapData.
+ */
+public class UnsafeMapData extends MapData {
+
+  public final UnsafeArrayData keys;
+  public final UnsafeArrayData values;
+  // The number of elements in this map
+  private int numElements;
+  // The size of this map's backing data, in bytes
+  private int sizeInBytes;
+
+  public int getSizeInBytes() { return sizeInBytes; }
+
+  public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+    assert keys.numElements() == values.numElements();
+    this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
+    this.numElements = keys.numElements();
+    this.keys = keys;
+    this.values = values;
+  }
+
+  @Override
+  public int numElements() {
+    return numElements;
+  }
+
+  @Override
+  public ArrayData keyArray() {
+    return keys;
+  }
+
+  @Override
+  public ArrayData valueArray() {
+    return values;
+  }
+
+  @Override
+  public UnsafeMapData copy() {
+    return new UnsafeMapData(keys.copy(), values.copy());
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
new file mode 100644
index 000000000000..b521b703389d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class UnsafeReaders {
+
+  public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
+    // Read the number of elements from the first 4 bytes.
+    final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
+    final UnsafeArrayData array = new UnsafeArrayData();
+    // Skip the first 4 bytes.
+    array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
+    return array;
+  }
+
+  public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
+    // Read the number of elements from the first 4 bytes.
+    final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
+    // Read the numBytes of the key array from the second 4 bytes.
+    final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4);
+    final int valueArraySize = numBytes - 8 - keyArraySize;
+
+    final UnsafeArrayData keyArray = new UnsafeArrayData();
+    keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
+
+    final UnsafeArrayData valueArray = new UnsafeArrayData();
+    valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
+
+    return new UnsafeMapData(keyArray, valueArray);
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1b475b249274..fead1f3a5990 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -291,6 +291,10 @@ public Object get(int ordinal, DataType dataType) {
       return getInterval(ordinal);
     } else if (dataType instanceof StructType) {
       return getStruct(ordinal, ((StructType) dataType).size());
+    } else if (dataType instanceof ArrayType) {
+      return getArray(ordinal);
+    } else if (dataType instanceof MapType) {
+      return getMap(ordinal);
     } else {
       throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
     }
@@ -346,7 +350,6 @@ public double getDouble(int ordinal) {
 
   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    assertIndexIsValid(ordinal);
     if (isNullAt(ordinal)) {
       return null;
     }
@@ -362,7 +365,6 @@ public Decimal getDecimal(int ordinal, int precision, int scale) {
 
   @Override
   public UTF8String getUTF8String(int ordinal) {
-    assertIndexIsValid(ordinal);
     if (isNullAt(ordinal)) return null;
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
@@ -372,7 +374,6 @@ public UTF8String getUTF8String(int ordinal) {
 
   @Override
   public byte[] getBinary(int ordinal) {
-    assertIndexIsValid(ordinal);
     if (isNullAt(ordinal)) {
       return null;
     } else {
@@ -410,7 +411,6 @@ public UnsafeRow getStruct(int ordinal, int numFields) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
-      assertIndexIsValid(ordinal);
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) (offsetAndSize & ((1L << 32) - 1));
@@ -420,11 +420,33 @@ public UnsafeRow getStruct(int ordinal, int numFields) {
     }
   }
 
+  @Override
+  public ArrayData getArray(int ordinal) {
+    if (isNullAt(ordinal)) {
+      return null;
+    } else {
+      final long offsetAndSize = getLong(ordinal);
+      final int offset = (int) (offsetAndSize >> 32);
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+    }
+  }
+
+  @Override
+  public MapData getMap(int ordinal) {
+    if (isNullAt(ordinal)) {
+      return null;
+    } else {
+      final long offsetAndSize = getLong(ordinal);
+      final int offset = (int) (offsetAndSize >> 32);
+      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+    }
+  }
+
   /**
    * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
    * byte array rather than referencing data stored in a data page.
-   *
- * This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
public UnsafeRow copy() {
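
Taken together, `UnsafeArrayData` and `UnsafeReaders.readArray` define a simple wire format: a 4-byte `numElements` header, then one 4-byte offset per element (negative for a null element), then the variable-length values. The following standalone sketch builds and decodes that layout for the array `[1L, null, 2L]`; it uses a plain `ByteBuffer` in place of Spark's `PlatformDependent` memory access, and the class name and buffer-based decoding are illustrative assumptions rather than code from this patch.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class ArrayLayoutDemo {
  public static void main(String[] args) {
    // Layout: [numElements][offsets][values]. Offsets are relative to the
    // start of the offset region (i.e. the array body after numElements).
    int numElements = 3;
    int fixedSize = 4 * numElements;  // one 4-byte offset per element
    ByteBuffer buf = ByteBuffer.allocate(4 + fixedSize + 16).order(ByteOrder.nativeOrder());

    buf.putInt(numElements);          // the leading numElements word
    buf.putInt(fixedSize);            // element 0 starts right after the offsets (12)
    buf.putInt(-(fixedSize + 8));     // element 1 is null: negated next-cursor (-20)
    buf.putInt(fixedSize + 8);        // element 2 starts after element 0's 8 bytes (20)
    buf.putLong(1L);                  // value of element 0
    buf.putLong(2L);                  // value of element 2

    // Decode the way UnsafeArrayData does: negative offset means null.
    int n = buf.getInt(0);
    for (int i = 0; i < n; i++) {
      int offset = buf.getInt(4 + i * 4);
      if (offset < 0) {
        System.out.println("element " + i + " = null");
      } else {
        System.out.println("element " + i + " = " + buf.getLong(4 + offset));
      }
    }
  }
}
```

The null element stores `-(fixedSize + 8)`, the negated cursor position where the next element starts, which matches what the generated writer in `GenerateUnsafeProjection` emits for nulls and keeps `getElementSize` (which takes `Math.abs` of the next offset) correct.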
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index f43a285cd6ca..31928731545d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -19,6 +19,7 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.MapData;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
@@ -185,4 +186,74 @@ public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInter
return 16;
}
}
+
+ public static class ArrayWriter {
+
+ public static int getSize(UnsafeArrayData input) {
+ // we need extra 4 bytes to store the number of elements in this array.
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) {
+ final int numBytes = input.getSizeInBytes() + 4;
+ final long offset = target.getBaseOffset() + cursor;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements());
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(target.getBaseObject(), offset + 4);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
+
+ public static class MapWriter {
+
+ public static int getSize(UnsafeMapData input) {
+ // we need extra 8 bytes to store the number of elements and the numBytes of the key array.
+ final int sizeInBytes = 4 + 4 + input.getSizeInBytes();
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
+ final long offset = target.getBaseOffset() + cursor;
+ final UnsafeArrayData keyArray = input.keys;
+ final UnsafeArrayData valueArray = input.values;
+ final int keysNumBytes = keyArray.getSizeInBytes();
+ final int valuesNumBytes = valueArray.getSizeInBytes();
+ final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements());
+ // write the numBytes of the key array into the second 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes);
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes of key array to the variable length portion.
+ keyArray.writeToMemory(target.getBaseObject(), offset + 8);
+
+ // Write the bytes of value array to the variable length portion.
+ valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
}
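
Both writers record where a field's variable-length data lives by packing two ints into the field's fixed-length slot: the upper 32 bits of the long hold the cursor (the byte offset of the data within the row) and the lower 32 bits hold its size. A minimal round-trip of that packing, mirroring the `target.setLong(...)` call above and the decoding in `UnsafeRow.getArray`/`getMap` (the class name and sample values are illustrative):

```java
public class OffsetAndSizeDemo {
  // Pack a (cursor, numBytes) pair the way UnsafeRowWriters does.
  static long pack(int cursor, int numBytes) {
    return (((long) cursor) << 32) | ((long) numBytes);
  }

  public static void main(String[] args) {
    long offsetAndSize = pack(48, 28);

    // Decode the way UnsafeRow.getArray/getMap do.
    int offset = (int) (offsetAndSize >> 32);
    int size = (int) (offsetAndSize & ((1L << 32) - 1));

    System.out.println(offset + " bytes into the row, " + size + " bytes long");
    // prints: 48 bytes into the row, 28 bytes long
  }
}
```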
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
new file mode 100644
index 000000000000..0e8e405d055d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A set of helper methods to write data into the variable length portion.
+ */
+public class UnsafeWriters {
+ public static void writeToMemory(
+ Object inputObject,
+ long inputOffset,
+ Object targetObject,
+ long targetOffset,
+ int numBytes) {
+
+ // zero-out the padding bytes (disabled until getRoundedSize below does word alignment)
+// if ((numBytes & 0x07) > 0) {
+// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L);
+// }
+
+ // Write the UnsafeData to the target memory.
+ PlatformDependent.copyMemory(
+ inputObject,
+ inputOffset,
+ targetObject,
+ targetOffset,
+ numBytes
+ );
+ }
+
+ public static int getRoundedSize(int size) {
+ //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size);
+ // todo: do word alignment
+ return size;
+ }
+
+ /** Writer for Decimal with precision larger than 18. */
+ public static class DecimalWriter {
+
+ public static int getSize(Decimal input) {
+ return 16;
+ }
+
+ public static int write(Object targetObject, long targetOffset, Decimal input) {
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ final int numBytes = bytes.length;
+ assert(numBytes <= 16);
+
+ // zero-out the bytes
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L);
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L);
+
+ // Write the bytes to the variable length portion.
+ PlatformDependent.copyMemory(bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ targetObject,
+ targetOffset,
+ numBytes);
+
+ return 16;
+ }
+ }
+
+ /** Writer for UTF8String. */
+ public static class UTF8StringWriter {
+
+ public static int getSize(UTF8String input) {
+ return getRoundedSize(input.numBytes());
+ }
+
+ public static int write(Object targetObject, long targetOffset, UTF8String input) {
+ final int numBytes = input.numBytes();
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for binary (byte array) type. */
+ public static class BinaryWriter {
+
+ public static int getSize(byte[] input) {
+ return getRoundedSize(input.length);
+ }
+
+ public static int write(Object targetObject, long targetOffset, byte[] input) {
+ final int numBytes = input.length;
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET,
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for UnsafeRow. */
+ public static class StructWriter {
+
+ public static int getSize(UnsafeRow input) {
+ return getRoundedSize(input.getSizeInBytes());
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeRow input) {
+ final int numBytes = input.getSizeInBytes();
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for interval type. */
+ public static class IntervalWriter {
+
+ public static int getSize(CalendarInterval input) {
+ return 16;
+ }
+
+ public static int write(Object targetObject, long targetOffset, CalendarInterval input) {
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months);
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds);
+
+ return 16;
+ }
+ }
+
+ /** Writer for UnsafeArrayData. */
+ public static class ArrayWriter {
+
+ public static int getSize(UnsafeArrayData input) {
+ // we need extra 4 bytes to store the number of elements in this array.
+ return getRoundedSize(input.getSizeInBytes() + 4);
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) {
+ final int numBytes = input.getSizeInBytes();
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements());
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset + 4, numBytes);
+
+ return getRoundedSize(numBytes + 4);
+ }
+ }
+
+ public static class MapWriter {
+
+ public static int getSize(UnsafeMapData input) {
+ // we need extra 8 bytes to store the number of elements and the numBytes of the key array.
+ return getRoundedSize(4 + 4 + input.getSizeInBytes());
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
+ final UnsafeArrayData keyArray = input.keys;
+ final UnsafeArrayData valueArray = input.values;
+ final int keysNumBytes = keyArray.getSizeInBytes();
+ final int valuesNumBytes = valueArray.getSizeInBytes();
+ final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements());
+ // write the numBytes of the key array into the second 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes);
+
+ // Write the bytes of key array to the variable length portion.
+ writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(),
+ targetObject, targetOffset + 8, keysNumBytes);
+
+ // Write the bytes of value array to the variable length portion.
+ writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(),
+ targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+}
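
Putting `MapWriter.write` and `UnsafeReaders.readMap` side by side, a serialized map is `[numElements] [key array numBytes] [key array bytes] [value array bytes]`, where each array blob is in the UnsafeArrayData body format (offsets then values) and both arrays share the map's element count. A hand-rolled sketch of that framing for the one-entry map `{9 -> 7L}`, again with `ByteBuffer` standing in for raw memory as an assumption:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class MapLayoutDemo {
  public static void main(String[] args) {
    byte[] keyArray = keyArrayBytes();     // [offset 4][int 9]   -> 8 bytes
    byte[] valueArray = valueArrayBytes(); // [offset 4][long 7]  -> 12 bytes

    ByteBuffer buf = ByteBuffer.allocate(8 + keyArray.length + valueArray.length)
        .order(ByteOrder.nativeOrder());
    buf.putInt(1);                 // numElements, shared by keys and values
    buf.putInt(keyArray.length);   // numBytes of the key array
    buf.put(keyArray).put(valueArray);

    // readMap recovers the two regions from the first 8 bytes:
    int numElements = buf.getInt(0);
    int keySize = buf.getInt(4);
    int valueSize = buf.capacity() - 8 - keySize;
    System.out.println(numElements + " entry, keys=" + keySize + "B, values=" + valueSize + "B");
  }

  static byte[] keyArrayBytes() {
    ByteBuffer b = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder());
    b.putInt(4).putInt(9);         // one 4-byte offset, then the int key
    return b.array();
  }

  static byte[] valueArrayBytes() {
    ByteBuffer b = ByteBuffer.allocate(12).order(ByteOrder.nativeOrder());
    b.putInt(4).putLong(7L);       // one 4-byte offset, then the long value
    return b.array();
  }
}
```

This matches the sizes asserted in the new test suite below: a key array of `4 + 4` bytes and a value array of `4 + 8` bytes per entry.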
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
new file mode 100644
index 000000000000..3caf0fb3410c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+
+case class FromUnsafe(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with CodegenFallback {
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(ArrayType, StructType, MapType))
+
+ override def dataType: DataType = child.dataType
+
+ private def convert(value: Any, dt: DataType): Any = dt match {
+ case StructType(fields) =>
+ val row = value.asInstanceOf[UnsafeRow]
+ val result = new Array[Any](fields.length)
+ fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) =>
+ if (!row.isNullAt(i)) {
+ result(i) = convert(row.get(i, dt), dt)
+ }
+ }
+ new GenericInternalRow(result)
+
+ case ArrayType(elementType, _) =>
+ val array = value.asInstanceOf[UnsafeArrayData]
+ val length = array.numElements()
+ val result = new Array[Any](length)
+ var i = 0
+ while (i < length) {
+ if (!array.isNullAt(i)) {
+ result(i) = convert(array.get(i, elementType), elementType)
+ }
+ i += 1
+ }
+ new GenericArrayData(result)
+
+ case MapType(kt, vt, _) =>
+ val map = value.asInstanceOf[UnsafeMapData]
+ val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData]
+ val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData]
+ new ArrayBasedMapData(safeKeyArray, safeValueArray)
+
+ case _ => value
+ }
+
+ override def nullSafeEval(input: Any): Any = {
+ convert(input, dataType)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 83129dc12dff..79649741025a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -151,7 +151,15 @@ object FromUnsafeProjection {
* Returns an UnsafeProjection for given Array of DataTypes.
*/
def apply(fields: Seq[DataType]): Projection = {
- create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ create(fields.zipWithIndex.map(x => {
+ val b = new BoundReference(x._2, x._1, true)
+ // todo: this is quite slow, maybe remove this whole projection once the generic
+ // getter of InternalRow is removed?
+ b.dataType match {
+ case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
+ case _ => b
+ }
+ }))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3177e6b75084..d58899140383 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -305,7 +305,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[Decimal].getName,
classOf[CalendarInterval].getName,
classOf[ArrayData].getName,
- classOf[MapData].getName
+ classOf[UnsafeArrayData].getName,
+ classOf[MapData].getName,
+ classOf[UnsafeMapData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 6c9908604668..1d6a9c5e7abd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
/**
* Generates a [[Projection]] that returns an [[UnsafeRow]].
@@ -37,14 +38,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName
private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
+ private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
+ private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
+
+ private val PlatformDependent = classOf[PlatformDependent].getName
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
- case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+ case t: AtomicType => true
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case NullType => true
- case t: DecimalType => true
+ case t: ArrayType if canSupport(t.elementType) => true
+ case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
case _ => false
}
@@ -59,6 +65,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s" + (${ev.isNull} ? 0 : 16)"
case _: StructType =>
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+ case _: ArrayType =>
+ s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
+ case _: MapType =>
+ s" + (${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive}))"
case _ => ""
}
@@ -95,8 +105,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case CalendarIntervalType =>
s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
- case t: StructType =>
+ case _: StructType =>
s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case _: ArrayType =>
+ s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+ case _: MapType =>
+ s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
@@ -148,7 +162,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$ret.pointTo(
$buffer,
- org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+ $PlatformDependent.BYTE_ARRAY_OFFSET,
${expressions.size},
$numBytes);
int $cursor = $fixedSize;
@@ -237,7 +251,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
| $primitive.pointTo(
| $buffer,
- | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
+ | $PlatformDependent.BYTE_ARRAY_OFFSET,
| ${exprs.size},
| $numBytes);
| int $cursor = $fixedSize;
@@ -250,6 +264,303 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
GeneratedExpressionCode(code, isNull, primitive)
}
+ /**
+ * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
+ *
+ * @param ctx code generation context
+ * @param inputs the generated code for the expressions, or for the input struct fields.
+ * @param inputTypes types of the inputs
+ */
+ private def createCodeForStruct2(
+ ctx: CodeGenContext,
+ inputs: Seq[GeneratedExpressionCode],
+ inputTypes: Seq[DataType]): GeneratedExpressionCode = {
+
+ val output = ctx.freshName("convertedStruct")
+ ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();")
+ val buffer = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val numBytes = ctx.freshName("numBytes")
+ val cursor = ctx.freshName("cursor")
+
+ val convertedFields = inputTypes.zip(inputs).map { case (dt, input) =>
+ createConvertCode(ctx, input, dt)
+ }
+
+ val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
+ val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) =>
+ genAdditionalSize(dt, ev)
+ }.mkString("")
+
+ val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) =>
+ val update = genFieldWriter(ctx, dt, ev, output, i, cursor)
+ s"""
+ if (${ev.isNull}) {
+ $output.setNullAt($i);
+ } else {
+ $update;
+ }
+ """
+ }.mkString("\n")
+
+ val code = s"""
+ ${convertedFields.map(_.code).mkString("\n")}
+
+ final int $numBytes = $fixedSize $additionalSize;
+ if ($numBytes > $buffer.length) {
+ $buffer = new byte[$numBytes];
+ }
+
+ $output.pointTo(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET,
+ ${inputTypes.length},
+ $numBytes);
+
+ int $cursor = $fixedSize;
+
+ $fieldWriters
+ """
+ GeneratedExpressionCode(code, "false", output)
+ }
+
+ private def getWriter(dt: DataType) = dt match {
+ case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName
+ case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName
+ case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName
+ case _: StructType => classOf[UnsafeWriters.StructWriter].getName
+ case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName
+ case _: MapType => classOf[UnsafeWriters.MapWriter].getName
+ case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName
+ }
+
+ private def createCodeForArray(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ elementType: DataType): GeneratedExpressionCode = {
+ val output = ctx.freshName("convertedArray")
+ ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
+ val buffer = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val outputIsNull = ctx.freshName("isNull")
+ val tmp = ctx.freshName("tmp")
+ val numElements = ctx.freshName("numElements")
+ val fixedSize = ctx.freshName("fixedSize")
+ val numBytes = ctx.freshName("numBytes")
+ val elements = ctx.freshName("elements")
+ val cursor = ctx.freshName("cursor")
+ val index = ctx.freshName("index")
+
+ val element = GeneratedExpressionCode(
+ code = "",
+ isNull = s"$tmp.isNullAt($index)",
+ primitive = s"${ctx.getValue(tmp, elementType, index)}"
+ )
+ val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType)
+
+ // go through the input array to calculate how many bytes we need.
+ val calculateNumBytes = elementType match {
+ case _ if (ctx.isPrimitiveType(elementType)) =>
+ // Should we do word alignment?
+ val elementSize = elementType.defaultSize
+ s"""
+ $numBytes += $elementSize * $numElements;
+ """
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"""
+ $numBytes += 8 * $numElements;
+ """
+ case _ =>
+ val writer = getWriter(elementType)
+ val elementSize = s"$writer.getSize($elements[$index])"
+ val unsafeType = elementType match {
+ case _: StructType => "UnsafeRow"
+ case _: ArrayType => "UnsafeArrayData"
+ case _: MapType => "UnsafeMapData"
+ case _ => ctx.javaType(elementType)
+ }
+ val copy = elementType match {
+ // We reuse the buffer during conversion, so we need to copy it before processing the next element.
+ case _: StructType | _: ArrayType | _: MapType => ".copy()"
+ case _ => ""
+ }
+
+ s"""
+ final $unsafeType[] $elements = new $unsafeType[$numElements];
+ for (int $index = 0; $index < $numElements; $index++) {
+ ${convertedElement.code}
+ if (!${convertedElement.isNull}) {
+ $elements[$index] = ${convertedElement.primitive}$copy;
+ $numBytes += $elementSize;
+ }
+ }
+ """
+ }
+
+ val writeElement = elementType match {
+ case _ if (ctx.isPrimitiveType(elementType)) =>
+ // Should we do word alignment?
+ val elementSize = elementType.defaultSize
+ s"""
+ $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ ${convertedElement.primitive});
+ $cursor += $elementSize;
+ """
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"""
+ $PlatformDependent.UNSAFE.putLong(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ ${convertedElement.primitive}.toUnscaledLong());
+ $cursor += 8;
+ """
+ case _ =>
+ val writer = getWriter(elementType)
+ s"""
+ $cursor += $writer.write(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ $elements[$index]);
+ """
+ }
+
+ val checkNull = elementType match {
+ case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
+ case t: DecimalType => s"$elements[$index] == null" +
+ s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
+ case _ => s"$elements[$index] == null"
+ }
+
+ val code = s"""
+ ${input.code}
+ final boolean $outputIsNull = ${input.isNull};
+ if (!$outputIsNull) {
+ final ArrayData $tmp = ${input.primitive};
+ if ($tmp instanceof UnsafeArrayData) {
+ $output = (UnsafeArrayData) $tmp;
+ } else {
+ final int $numElements = $tmp.numElements();
+ final int $fixedSize = 4 * $numElements;
+ int $numBytes = $fixedSize;
+
+ $calculateNumBytes
+
+ if ($numBytes > $buffer.length) {
+ $buffer = new byte[$numBytes];
+ }
+
+ int $cursor = $fixedSize;
+ for (int $index = 0; $index < $numElements; $index++) {
+ if ($checkNull) {
+ // If the element is null, write a negative offset into the offset region.
+ $PlatformDependent.UNSAFE.putInt(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ -$cursor);
+ } else {
+ $PlatformDependent.UNSAFE.putInt(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ $cursor);
+
+ $writeElement
+ }
+ }
+
+ $output.pointTo(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET,
+ $numElements,
+ $numBytes);
+ }
+ }
+ """
+ GeneratedExpressionCode(code, outputIsNull, output)
+ }
+
+ private def createCodeForMap(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ keyType: DataType,
+ valueType: DataType): GeneratedExpressionCode = {
+ val output = ctx.freshName("convertedMap")
+ val outputIsNull = ctx.freshName("isNull")
+ val tmp = ctx.freshName("tmp")
+
+ val keyArray = GeneratedExpressionCode(
+ code = "",
+ isNull = "false",
+ primitive = s"$tmp.keyArray()"
+ )
+ val valueArray = GeneratedExpressionCode(
+ code = "",
+ isNull = "false",
+ primitive = s"$tmp.valueArray()"
+ )
+ val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType)
+ val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType)
+
+ val code = s"""
+ ${input.code}
+ final boolean $outputIsNull = ${input.isNull};
+ UnsafeMapData $output = null;
+ if (!$outputIsNull) {
+ final MapData $tmp = ${input.primitive};
+ if ($tmp instanceof UnsafeMapData) {
+ $output = (UnsafeMapData) $tmp;
+ } else {
+ ${convertedKeys.code}
+ ${convertedValues.code}
+ $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive});
+ }
+ }
+ """
+ GeneratedExpressionCode(code, outputIsNull, output)
+ }
+
+ /**
+ * Generates the Java code to convert a value to its unsafe version.
+ */
+ private def createConvertCode(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ dataType: DataType): GeneratedExpressionCode = dataType match {
+ case t: StructType =>
+ val output = ctx.freshName("convertedStruct")
+ val outputIsNull = ctx.freshName("isNull")
+ val tmp = ctx.freshName("tmp")
+ val fieldTypes = t.fields.map(_.dataType)
+ val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
+ val getFieldCode = ctx.getValue(tmp, dt, i.toString)
+ val fieldIsNull = s"$tmp.isNullAt($i)"
+ GeneratedExpressionCode("", fieldIsNull, getFieldCode)
+ }
+ val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes)
+ val code = s"""
+ ${input.code}
+ UnsafeRow $output = null;
+ final boolean $outputIsNull = ${input.isNull};
+ if (!$outputIsNull) {
+ final InternalRow $tmp = ${input.primitive};
+ if ($tmp instanceof UnsafeRow) {
+ $output = (UnsafeRow) $tmp;
+ } else {
+ ${converter.code}
+ $output = ${converter.primitive};
+ }
+ }
+ """
+ GeneratedExpressionCode(code, outputIsNull, output)
+
+ case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
+
+ case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt)
+
+ case _ => input
+ }
+
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -259,10 +570,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
val ctx = newCodeGenContext()
- val isNull = ctx.freshName("retIsNull")
- val primitive = ctx.freshName("retValue")
- val eval = GeneratedExpressionCode("", isNull, primitive)
- eval.code = createCode(ctx, eval, expressions)
+ val exprEvals = expressions.map(e => e.gen(ctx))
+ val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType))
val code = s"""
public Object generate($exprType[] exprs) {
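
At runtime, the code emitted by `createCodeForStruct2` always follows the same shape: compute the row's fixed-size region (8 bytes per field plus the null bitset), add the variable-length contributions, grow the reused byte buffer if needed, re-`pointTo` the reused `UnsafeRow`, then write each field while a cursor tracks the variable-length region. A hand-written Java analogue for a two-field `(long, nullable long)` schema, compiled against the classes this patch touches (the class itself is an illustrative sketch, not the emitted source):

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.PlatformDependent;

// Hand-written analogue of the generated conversion for a (long, nullable long) row.
public class GeneratedShape {
  private final UnsafeRow output = new UnsafeRow();
  private byte[] buffer = new byte[64];

  public UnsafeRow apply(long field0, boolean field1IsNull, long field1) {
    final int numFields = 2;
    final int fixedSize = 8 * numFields + UnsafeRow.calculateBitSetWidthInBytes(numFields);
    final int numBytes = fixedSize;  // no variable-length fields in this sketch

    if (numBytes > buffer.length) {  // grow the reused buffer on demand
      buffer = new byte[numBytes];
    }
    output.pointTo(buffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, numBytes);

    int cursor = fixedSize;          // variable-length data would be written here
    output.setLong(0, field0);
    if (field1IsNull) {
      output.setNullAt(1);
    } else {
      output.setLong(1, field1);
    }
    return output;
  }
}
```

Note that, as in the generated code, both the `UnsafeRow` and the buffer are reused across invocations, which is why the array conversion path above has to `.copy()` nested structs, arrays, and maps before moving on to the next element.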
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index db4876355dae..f6fa021adee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -22,6 +22,9 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
override def numElements(): Int = keyArray.numElements()
+ override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
+
+ // We need to check equality of map data in tests.
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[ArrayBasedMapData]) {
return false
@@ -32,15 +35,15 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
return false
}
- this.keyArray == other.keyArray && this.valueArray == other.valueArray
+ ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
}
override def hashCode: Int = {
- keyArray.hashCode() * 37 + valueArray.hashCode()
+ ArrayBasedMapData.toScalaMap(this).hashCode()
}
override def toString(): String = {
- s"keys: $keyArray\nvalues: $valueArray"
+ s"keys: $keyArray, values: $valueArray"
}
}
@@ -48,4 +51,10 @@ object ArrayBasedMapData {
def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
}
+
+ def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
+ val keys = map.keyArray.asInstanceOf[GenericArrayData].array
+ val values = map.valueArray.asInstanceOf[GenericArrayData].array
+ keys.zip(values).toMap
+ }
}
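
The new `equals`/`hashCode` go through `toScalaMap` so that the comparison is order-insensitive: two maps whose entries were built in different orders should compare equal, which the old element-wise comparison of `keyArray`/`valueArray` would reject. A minimal Java illustration of the difference (plain collections standing in for the `ArrayData` pair, as an assumption for brevity):

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MapEqualityDemo {
  public static void main(String[] args) {
    // Same logical map {1 -> 10, 2 -> 20}, entries in different orders.
    List<Integer> keysA = Arrays.asList(1, 2);
    List<Long> valuesA = Arrays.asList(10L, 20L);
    List<Integer> keysB = Arrays.asList(2, 1);
    List<Long> valuesB = Arrays.asList(20L, 10L);

    // Array-wise comparison (the old equals) is order-sensitive:
    System.out.println(keysA.equals(keysB));  // false

    // Map-wise comparison (the new equals) is not:
    Map<Integer, Long> a = new HashMap<>();
    Map<Integer, Long> b = new HashMap<>();
    for (int i = 0; i < keysA.size(); i++) a.put(keysA.get(i), valuesA.get(i));
    for (int i = 0; i < keysB.size(); i++) b.put(keysB.get(i), valuesB.get(i));
    System.out.println(a.equals(b));          // true
  }
}
```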
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index c99fc233255e..642c56f12ded 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -17,11 +17,15 @@
package org.apache.spark.sql.types
+import scala.reflect.ClassTag
+
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
abstract class ArrayData extends SpecializedGetters with Serializable {
def numElements(): Int
+ def copy(): ArrayData
+
def toBooleanArray(): Array[Boolean] = {
val size = numElements()
val values = new Array[Boolean](size)
@@ -99,19 +103,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
values
}
- def toArray[T](elementType: DataType): Array[T] = {
+ def toArray[T: ClassTag](elementType: DataType): Array[T] = {
val size = numElements()
- val values = new Array[Any](size)
+ val values = new Array[T](size)
var i = 0
while (i < size) {
if (isNullAt(i)) {
- values(i) = null
+ values(i) = null.asInstanceOf[T]
} else {
- values(i) = get(i, elementType)
+ values(i) = get(i, elementType).asInstanceOf[T]
}
i += 1
}
- values.asInstanceOf[Array[T]]
+ values
}
// todo: specialize this.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index b3e75f8bad50..b314acdfe364 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,13 +17,19 @@
package org.apache.spark.sql.types
+import scala.reflect.ClassTag
+
import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
-class GenericArrayData(array: Array[Any]) extends ArrayData with GenericSpecializedGetters {
+class GenericArrayData(private[sql] val array: Array[Any])
+ extends ArrayData with GenericSpecializedGetters {
override def genericGet(ordinal: Int): Any = array(ordinal)
- override def toArray[T](elementType: DataType): Array[T] = array.asInstanceOf[Array[T]]
+ override def copy(): ArrayData = new GenericArrayData(array.clone())
+
+ // todo: Array is invariant in Scala, maybe use toSeq instead?
+ override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T])
override def numElements(): Int = array.length
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
index 5514c3cd8546..f50969f0f0b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
@@ -25,6 +25,8 @@ abstract class MapData extends Serializable {
def valueArray(): ArrayData
+ def copy(): MapData
+
def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
val length = numElements()
val keys = keyArray()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 44f845620a10..59491c5ba160 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.unsafe.types.UTF8String
class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
+ private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
+
test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
val converter = UnsafeProjection.create(fieldTypes)
@@ -73,8 +75,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val unsafeRow: UnsafeRow = converter.apply(row)
assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
+ roundedSize("Hello".getBytes.length) +
+ roundedSize("World".getBytes.length))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
@@ -92,8 +94,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
val unsafeRow: UnsafeRow = converter.apply(row)
- assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
+ assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
@@ -172,6 +173,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r
}
+ // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -235,4 +237,108 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val converter = UnsafeProjection.create(fieldTypes)
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
+
+ test("basic conversion with array type") {
+ val fieldTypes: Array[DataType] = Array(
+ ArrayType(LongType),
+ ArrayType(ArrayType(LongType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val array1 = new GenericArrayData(Array[Any](1L, 2L))
+ val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, array1)
+ row.update(1, array2)
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
+ assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
+ assert(unsafeArray1.numElements() == 2)
+ assert(unsafeArray1.getLong(0) == 1L)
+ assert(unsafeArray1.getLong(1) == 2L)
+
+ val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
+ assert(unsafeArray2.numElements() == 1)
+
+ val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
+ assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
+ assert(nestedArray.numElements() == 2)
+ assert(nestedArray.getLong(0) == 3L)
+ assert(nestedArray.getLong(1) == 4L)
+
+ assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+
+ val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
+ val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
+ }
+
+ test("basic conversion with map type") {
+ def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+ def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
+ val numElements = keys.length
+ assert(map.numElements() == numElements)
+
+ val keyArray = map.keys
+ assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
+ assert(keyArray.numElements() == numElements)
+ keys.zipWithIndex.foreach { case (key, i) =>
+ assert(keyArray.getInt(i) == key)
+ }
+
+ val valueArray = map.values
+ assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
+ assert(valueArray.numElements() == numElements)
+ values.zipWithIndex.foreach { case (value, i) =>
+ assert(valueArray.getLong(i) == value)
+ }
+
+ assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ }
+
+ val fieldTypes: Array[DataType] = Array(
+ MapType(IntegerType, LongType),
+ MapType(IntegerType, MapType(IntegerType, LongType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+
+ val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
+ val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, map1)
+ row.update(1, map2)
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
+ testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+
+ val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
+ assert(unsafeMap2.numElements() == 1)
+
+ val keyArray = unsafeMap2.keys
+ assert(keyArray.getSizeInBytes == 4 + 4)
+ assert(keyArray.numElements() == 1)
+ assert(keyArray.getInt(0) == 9)
+
+ val valueArray = unsafeMap2.values
+ assert(valueArray.numElements() == 1)
+ val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
+ testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
+ assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+
+ assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
+ val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+ }
}
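
The size assertions in these tests can be re-derived by hand from the format: each array body stores a 4-byte offset per element plus the values; embedding an array in a row adds a 4-byte `numElements` prefix and rounds the blob up to a whole word; and the row itself spends 8 bytes on the null bitset plus an 8-byte offset-and-size slot per field. A worked recomputation of the numbers asserted in the array test (a sketch mirroring `roundedSize` above; the class name is ours):

```java
public class ArraySizeArithmetic {
  static int rounded(int size) {            // round up to the next multiple of 8
    return ((size + 7) / 8) * 8;
  }

  public static void main(String[] args) {
    int array1 = 4 * 2 + 8 * 2;             // two offsets + two longs = 24
    int nested = 4 * 2 + 8 * 2;             // the inner [3L, 4L] array = 24
    int array2 = 4 + (4 + nested);          // one offset + (numElements + nested body) = 32

    int array1Size = rounded(4 + array1);   // 28 -> 32
    int array2Size = rounded(4 + array2);   // 36 -> 40

    // null bitset + two fixed-length slots + the two variable-length blobs
    int rowSize = 8 + 8 * 2 + array1Size + array2Size;
    System.out.println(rowSize);            // 96
  }
}
```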
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 916825d007cc..f6c9b87778f8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -43,6 +43,9 @@ public final class UTF8String implements Comparable