From d5d3106a8dbb9ce6c8b30b18ea9fe6c0fbde7f62 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 14 May 2015 22:55:29 -0700 Subject: [PATCH 01/62] WIP towards external sorter for Spark SQL. This is based on an early version of my shuffle sort patch; the implementation will undergo significant refactoring based on improvements made as part of the shuffle patch. Stay tuned. --- .../DummySerializerInstance.java | 9 +- .../unsafe/UnsafeShuffleExternalSorter.java | 1 + .../unsafe/sort/PrefixComparator.java | 26 ++ .../unsafe/sort/RecordComparator.java | 37 +++ .../sort/RecordPointerAndKeyPrefix.java | 31 +++ .../unsafe/sort/UnsafeExternalSorter.java | 235 ++++++++++++++++++ .../unsafe/sort/UnsafeInMemorySorter.java | 175 +++++++++++++ .../unsafe/sort/UnsafeSortDataFormat.java | 80 ++++++ .../unsafe/sort/UnsafeSorterIterator.java | 35 +++ .../unsafe/sort/UnsafeSorterSpillMerger.java | 91 +++++++ .../unsafe/sort/UnsafeSorterSpillReader.java | 89 +++++++ .../unsafe/sort/UnsafeSorterSpillWriter.java | 94 +++++++ 12 files changed, 898 insertions(+), 5 deletions(-) rename core/src/main/java/org/apache/spark/{shuffle/unsafe => serializer}/DummySerializerInstance.java (91%) create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java similarity index 91% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java rename to core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java index 3f746b886bc9b..0399abc63c235 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java +++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.serializer; import java.io.IOException; import java.io.InputStream; @@ -24,9 +24,7 @@ import scala.reflect.ClassTag; -import org.apache.spark.serializer.DeserializationStream; -import org.apache.spark.serializer.SerializationStream; -import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.PlatformDependent; /** @@ -35,7 +33,8 @@ * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work * around this, we pass a dummy no-op serializer. */ -final class DummySerializerInstance extends SerializerInstance { +@Private +public final class DummySerializerInstance extends SerializerInstance { public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 9e9ed94b7890c..56289573209fb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -30,6 +30,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java new file mode 100644 index 0000000000000..c41332ad117cb --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. + */ +public abstract class PrefixComparator { + public abstract int compare(long prefix1, long prefix2); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java new file mode 100644 index 0000000000000..09e4258792204 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +/** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ +public abstract class RecordComparator { + + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ + public abstract int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset); +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java new file mode 100644 index 0000000000000..0c4ebde407cfc --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +final class RecordPointerAndKeyPrefix { + /** + * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * description of how these addresses are encoded. + */ + public long recordPointer; + + /** + * A key prefix, for use in comparisons. + */ + public long keyPrefix; +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java new file mode 100644 index 0000000000000..0fdcbf7bfb3c3 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedList; + +/** + * External sorter based on {@link UnsafeInMemorySorter}. + */ +public final class UnsafeExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + + private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this + + private final PrefixComparator prefixComparator; + private final RecordComparator recordComparator; + private final int initialSize; + private int numSpills = 0; + private UnsafeInMemorySorter sorter; + + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final LinkedList allocatedPages = new LinkedList(); + private final boolean spillingEnabled; + private final int fileBufferSize; + private ShuffleWriteMetrics writeMetrics; + + + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + + private final LinkedList spillWriters = + new LinkedList(); + + public UnsafeExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + SparkConf conf) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.initialSize = initialSize; + this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); + // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + openSorter(); + } + + // TODO: metrics tracking + integration with shuffle write metrics + + private void openSorter() throws IOException { + this.writeMetrics = new ShuffleWriteMetrics(); + // TODO: connect write metrics to task metrics? + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire memory!"); + } + } + + this.sorter = + new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + } + + @VisibleForTesting + public void spill() throws IOException { + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); + spillWriters.add(spillWriter); + final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final Object baseObject = sortedRecords.getBaseObject(); + final long baseOffset = sortedRecords.getBaseOffset(); + // TODO: this assumption that the first long holds a length is not enforced via our interfaces + // We need to either always store this via the write path (e.g. not require the caller to do + // it), or provide interfaces / hooks for customizing the physical storage format etc. + final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); + } + spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); + numSpills++; + final long threadId = Thread.currentThread().getId(); + // TODO: messy; log _before_ spill + logger.info("Thread " + threadId + " spilling in-memory map of " + + org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + + (numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)"); + openSorter(); + } + + private long freeMemory() { + long memoryFreed = 0; + final Iterator iter = allocatedPages.iterator(); + while (iter.hasNext()) { + memoryManager.freePage(iter.next()); + shuffleMemoryManager.release(PAGE_SIZE); + memoryFreed += PAGE_SIZE; + iter.remove(); + } + currentPage = null; + currentPagePosition = -1; + return memoryFreed; + } + + private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) { + final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); + if (memoryAcquired < memoryToGrowSortBuffer) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandSortBuffer(); + shuffleMemoryManager.release(oldSortBufferMemoryUsage); + } + } + + final long spaceInCurrentPage; + if (currentPage != null) { + spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); + } else { + spaceInCurrentPage = 0; + } + if (requiredSpace > PAGE_SIZE) { + // TODO: throw a more specific exception? + throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else if (requiredSpace > spaceInCurrentPage) { + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpill != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpill); + throw new Exception("Can't allocate memory!"); + } + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE); + } + } + + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws Exception { + // Need 4 bytes to store the record length. + ensureSpaceInDataPage(lengthInBytes + 4); + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + spillMerger.addSpill(sorter.getSortedIterator()); + return spillMerger.getSortedIterator(); + } +} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 0000000000000..0ce7989463a23 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r2.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer); + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer); + return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + } else { + return prefixComparisonResult; + } + } + } + + private final TaskMemoryManager memoryManager; + private final Sorter sorter; + private final Comparator sortComparator; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ + private long[] sortBuffer; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int sortBufferInsertPosition = 0; + + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + assert (initialSize > 0); + this.sortBuffer = new long[initialSize * 2]; + this.memoryManager = memoryManager; + this.sorter = new Sorter(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } + + public long getMemoryUsage() { + return sortBuffer.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return sortBufferInsertPosition + 2 < sortBuffer.length; + } + + public void expandSortBuffer() { + final long[] oldBuffer = sortBuffer; + sortBuffer = new long[oldBuffer.length * 2]; + System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); + } + + /** + * Insert a record into the sort buffer. + * + * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + */ + public void insertRecord(long objectAddress, long keyPrefix) { + if (!hasSpaceForAnotherRecord()) { + expandSortBuffer(); + } + sortBuffer[sortBufferInsertPosition] = objectAddress; + sortBufferInsertPosition++; + sortBuffer[sortBufferInsertPosition] = keyPrefix; + sortBufferInsertPosition++; + } + + private static final class SortedIterator extends UnsafeSorterIterator { + + private final TaskMemoryManager memoryManager; + private final int sortBufferInsertPosition; + private final long[] sortBuffer; + private int position = 0; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + + SortedIterator( + TaskMemoryManager memoryManager, + int sortBufferInsertPosition, + long[] sortBuffer) { + this.memoryManager = memoryManager; + this.sortBufferInsertPosition = sortBufferInsertPosition; + this.sortBuffer = sortBuffer; + } + + @Override + public boolean hasNext() { + return position < sortBufferInsertPosition; + } + + @Override + public void loadNext() { + final long recordPointer = sortBuffer[position]; + baseObject = memoryManager.getPage(recordPointer); + baseOffset = memoryManager.getOffsetInPage(recordPointer); + keyPrefix = sortBuffer[position + 1]; + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java new file mode 100644 index 0000000000000..d09c728a7a638 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.util.collection.SortDataFormat; + +/** + * Supports sorting an array of (record pointer, key prefix) pairs. + * Used in {@link UnsafeInMemorySorter}. + *

+ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + */ +final class UnsafeSortDataFormat extends SortDataFormat { + + public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + + private UnsafeSortDataFormat() { } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); + } + + @Override + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data[pos * 2]; + reuse.keyPrefix = data[pos * 2 + 1]; + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + long tempPointer = data[pos0 * 2]; + long tempKeyPrefix = data[pos0 * 2 + 1]; + data[pos0 * 2] = data[pos1 * 2]; + data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; + data[pos1 * 2] = tempPointer; + data[pos1 * 2 + 1] = tempKeyPrefix; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos * 2] = src[srcPos * 2]; + dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + } + + @Override + public long[] allocate(int length) { + assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; + return new long[length * 2]; + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java new file mode 100644 index 0000000000000..f63ded826afb4 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public abstract class UnsafeSorterIterator { + + public abstract boolean hasNext(); + + public abstract void loadNext() throws IOException; + + public abstract Object getBaseObject(); + + public abstract long getBaseOffset(); + + public abstract int getRecordLength(); + + public abstract long getKeyPrefix(); +} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java new file mode 100644 index 0000000000000..94ee6699ef101 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; +import java.util.Comparator; +import java.util.PriorityQueue; + +final class UnsafeSorterSpillMerger { + + private final PriorityQueue priorityQueue; + + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final PrefixComparator prefixComparator) { + final Comparator comparator = new Comparator() { + + @Override + public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { + final int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; + } + } + }; + // TODO: the size is often known; incorporate size hints here. + priorityQueue = new PriorityQueue(10, comparator); + } + + public void addSpill(UnsafeSorterIterator spillReader) throws IOException { + if (spillReader.hasNext()) { + spillReader.loadNext(); + } + priorityQueue.add(spillReader); + } + + public UnsafeSorterIterator getSortedIterator() throws IOException { + return new UnsafeSorterIterator() { + + private UnsafeSorterIterator spillReader; + + @Override + public boolean hasNext() { + return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); + } + + @Override + public void loadNext() throws IOException { + if (spillReader != null) { + if (spillReader.hasNext()) { + spillReader.loadNext(); + priorityQueue.add(spillReader); + } + } + spillReader = priorityQueue.remove(); + } + + @Override + public Object getBaseObject() { return spillReader.getBaseObject(); } + + @Override + public long getBaseOffset() { return spillReader.getBaseOffset(); } + + @Override + public int getRecordLength() { return spillReader.getRecordLength(); } + + @Override + public long getKeyPrefix() { return spillReader.getKeyPrefix(); } + }; + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java new file mode 100644 index 0000000000000..bee555ad9a909 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import com.google.common.io.ByteStreams; + +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.PlatformDependent; + +final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + + private final File file; + private InputStream in; + private DataInputStream din; + + private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? + private int nextRecordLength; + + private long keyPrefix; + private final Object baseObject = arr; + private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; + + public UnsafeSorterSpillReader( + BlockManager blockManager, + File file, + BlockId blockId) throws IOException { + this.file = file; + assert (file.length() > 0); + final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); + this.in = blockManager.wrapForCompression(blockId, bs); + this.din = new DataInputStream(this.in); + nextRecordLength = din.readInt(); + } + + @Override + public boolean hasNext() { + return (in != null); + } + + @Override + public void loadNext() throws IOException { + keyPrefix = din.readLong(); + ByteStreams.readFully(in, arr, 0, nextRecordLength); + nextRecordLength = din.readInt(); + if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + in.close(); + in = null; + din = null; + } + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return 0; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } +} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java new file mode 100644 index 0000000000000..1b607b4db6921 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.*; + +import scala.Tuple2; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.TempLocalBlockId; +import org.apache.spark.unsafe.PlatformDependent; + +final class UnsafeSorterSpillWriter { + + private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + static final int EOF_MARKER = -1; + + private byte[] arr = new byte[SER_BUFFER_SIZE]; + + private final File file; + private final BlockId blockId; + private BlockObjectWriter writer; + private DataOutputStream dos; + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics) { + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + dos = new DataOutputStream(writer); + } + + public void write( + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { + dos.writeInt(recordLength); + dos.writeLong(keyPrefix); + PlatformDependent.copyMemory( + baseObject, + baseOffset + 4, + arr, + PlatformDependent.BYTE_ARRAY_OFFSET, + recordLength); + writer.write(arr, 0, recordLength); + // TODO: add a test that detects whether we leave this call out: + writer.recordWritten(); + } + + public void close() throws IOException { + dos.writeInt(EOF_MARKER); + writer.commitAndClose(); + writer = null; + dos = null; + arr = null; + } + + public long numberOfSpilledBytes() { + return file.length(); + } + + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { + return new UnsafeSorterSpillReader(blockManager, file, blockId); + } +} From 2bd8c9a34285b3d5c603bc083dc82d878198c833 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 15 May 2015 14:11:32 -0700 Subject: [PATCH 02/62] Import my original tests and get them to pass. --- .../unsafe/sort/UnsafeSorterSpillWriter.java | 9 +- .../spark/storage/BlockObjectWriter.scala | 9 +- .../sort/UnsafeExternalSorterSuite.java | 183 ++++++++++++++++++ .../sort/UnsafeInMemorySorterSuite.java | 142 ++++++++++++++ 4 files changed, 337 insertions(+), 6 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java create mode 100644 core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 1b607b4db6921..33eb37fbeeba2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -59,10 +59,10 @@ public UnsafeSorterSpillWriter( } public void write( - Object baseObject, - long baseOffset, - int recordLength, - long keyPrefix) throws IOException { + Object baseObject, + long baseOffset, + int recordLength, + long keyPrefix) throws IOException { dos.writeInt(recordLength); dos.writeLong(keyPrefix); PlatformDependent.copyMemory( @@ -72,7 +72,6 @@ public void write( PlatformDependent.BYTE_ARRAY_OFFSET, recordLength); writer.write(arr, 0, recordLength); - // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 7eeabd1e0489c..eef3c25a97b4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -211,7 +211,14 @@ private[spark] class DiskBlockObjectWriter( recordWritten() } - override def write(b: Int): Unit = throw new UnsupportedOperationException() + override def write(b: Int): Unit = { + // TOOD: re-enable the `throw new UnsupportedOperationException()` here + if (!initialized) { + open() + } + + bs.write(b) + } override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java new file mode 100644 index 0000000000000..ea746e03471be --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.File; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.UUID; + +import scala.Tuple2; +import scala.Tuple2$; +import scala.runtime.AbstractFunction1; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.AdditionalAnswers.returnsSecondArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeExternalSorterSuite { + + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + + File tempDir; + + private static final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + return stream; + } + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { + @Override + public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))) + .then(returnsSecondArg()); + } + + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { + final int[] arr = new int[] { value }; + sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); + } + + /** + * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. + */ + @Test + public void testSortingOnlyByPartitionId() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + insertNumber(sorter, 5); + insertNumber(sorter, 1); + insertNumber(sorter, 3); + sorter.spill(); + insertNumber(sorter, 4); + insertNumber(sorter, 2); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + iter.loadNext(); + assertEquals(1, iter.getKeyPrefix()); + iter.loadNext(); + assertEquals(2, iter.getKeyPrefix()); + iter.loadNext(); + assertEquals(3, iter.getKeyPrefix()); + iter.loadNext(); + assertEquals(4, iter.getKeyPrefix()); + iter.loadNext(); + assertEquals(5, iter.getKeyPrefix()); + assertFalse(iter.hasNext()); + // TODO: check that the values are also read back properly. + + // TODO: test for cleanup: + // assert(tempDir.isEmpty) + } + +} \ No newline at end of file diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java new file mode 100644 index 0000000000000..095d9ae5975d4 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Arrays; + +import org.junit.Test; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset) { + final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final byte[] strBytes = new byte[strLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + 8, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + mock(RecordComparator.class), + mock(PrefixComparator.class), + 100); + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + /** + * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. + */ + @Test + public void testSortingOnlyByPartitionId() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + for (String str : dataToSort) { + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); + position += 8; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + } + // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so + // use a dummy comparator + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + // Compute key prefixes based on the records' partition ids + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = new PrefixComparator() { + @Override + public int compare(long prefix1, long prefix2) { + return (int) prefix1 - (int) prefix2; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + prefixComparator, dataToSort.length); + // Given a page of records, insert those records into the sorter one-by-one: + position = dataPage.getBaseOffset(); + for (int i = 0; i < dataToSort.length; i++) { + // position now points to the start of a record (which holds its length). + final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); + final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); + final String str = getStringFromDataPage(baseObject, position); + final int partitionId = hashPartitioner.getPartition(str); + sorter.insertRecord(address, partitionId); + position += 8 + recordLength; + } + final UnsafeSorterIterator iter = sorter.getSortedIterator(); + int iterLength = 0; + long prevPrefix = -1; + Arrays.sort(dataToSort); + while (iter.hasNext()) { + iter.loadNext(); + final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset()); + final long keyPrefix = iter.getKeyPrefix(); + assertTrue(Arrays.asList(dataToSort).contains(str)); + assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); + prevPrefix = keyPrefix; + iterLength++; + } + assertEquals(dataToSort.length, iterLength); + } +} \ No newline at end of file From 58f36d01749d4eaa3c4ec6649b361ab415cae73c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 25 May 2015 20:47:07 -0700 Subject: [PATCH 03/62] Merge in a sketch of a unit test for the new sorter (now failing). --- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../execution/joins/UnsafeSortMergeJoin.scala | 190 ++++++++++++++++++ .../spark/sql/UnsafeSortMergeJoinSuite.scala | 52 +++++ .../execution/UnsafeExternalSortSuite.scala | 62 ++++++ .../UnsafeSortMergeCompatibiltySuite.scala | 41 ++++ 6 files changed, 348 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala create mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 0fdcbf7bfb3c3..96159ec1b1feb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -232,4 +232,4 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { spillMerger.addSpill(sorter.getSortedIterator()); return spillMerger.getSortedIterator(); } -} \ No newline at end of file +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 94ee6699ef101..2fb41fb2d402f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -26,8 +26,8 @@ final class UnsafeSorterSpillMerger { private final PriorityQueue priorityQueue; public UnsafeSorterSpillMerger( - final RecordComparator recordComparator, - final PrefixComparator prefixComparator) { + final RecordComparator recordComparator, + final PrefixComparator prefixComparator) { final Comparator comparator = new Comparator() { @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala new file mode 100644 index 0000000000000..73a9bc6a2cf60 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. + * TODO(josh): Document + */ +@DeveloperApi +case class UnsafeSortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + protected override def doExecute(): RDD[Row] = { + // Note that we purposely do not require out input to be sorted. Instead, we'll sort it + // ourselves using UnsafeExternalSorter. Not requiring the input to be sorted will prevent the + // Exchange from pushing the sort into the shuffle, which will allow the shuffle to benefit from + // Project Tungsten's shuffle optimizations which currently cannot be applied to shuffles that + // specify a key ordering. + + // Only sort if necessary: + val leftOrder = requiredOrders(leftKeys) + val leftResults = { + if (left.outputOrdering == leftOrder) { + left.execute().map(_.copy()) + } else { + new UnsafeExternalSort(leftOrder, global = false, left).execute() + } + } + val rightOrder = requiredOrders(rightKeys) + val rightResults = { + if (right.outputOrdering == rightOrder) { + right.execute().map(_.copy()) + } else { + new UnsafeExternalSort(rightOrder, global = false, right).execute() + } + } + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + println(leftElement) + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + println(right) + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. + */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala new file mode 100644 index 0000000000000..4ce0537f02418 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.execution.UnsafeExternalSort +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.scalatest.BeforeAndAfterEach + +class UnsafeSortMergeJoinSuite extends QueryTest with BeforeAndAfterEach { + // Ensures tables are loaded. + TestData + + conf.setConf(SQLConf.SORTMERGE_JOIN, "true") + conf.setConf(SQLConf.CODEGEN_ENABLED, "true") + conf.setConf(SQLConf.UNSAFE_ENABLED, "true") + conf.setConf(SQLConf.EXTERNAL_SORT, "true") + conf.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1") + + test("basic sort merge join test") { + val df = upperCaseData.join(lowerCaseData, $"n" === $"N") + print(df.queryExecution.optimizedPlan) + assert(df.queryExecution.sparkPlan.collect { + case smj: UnsafeSortMergeJoin => smj + }.nonEmpty) + checkAnswer( + df, + Seq( + Row(1, "A", 1, "a"), + Row(2, "B", 2, "b"), + Row(3, "C", 3, "c"), + Row(4, "D", 4, "d") + )) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala new file mode 100644 index 0000000000000..f5a6368a2b16a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +class UnsafeExternalSortSuite extends FunSuite with Matchers { + + private def createRow(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + + test("basic sorting") { + val sc = TestSQLContext.sparkContext + val sqlContext = new SQLContext(sc) + sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true") + + val schema: StructType = StructType( + StructField("word", StringType, nullable = false) :: + StructField("number", IntegerType, nullable = false) :: Nil) + val sortOrder: Seq[SortOrder] = Seq( + SortOrder(BoundReference(0, StringType, nullable = false), Ascending), + SortOrder(BoundReference(1, IntegerType, nullable = false), Descending)) + val rowsToSort: Seq[Row] = Seq( + createRow("Hello", 9), + createRow("World", 4), + createRow("Hello", 7), + createRow("Skinny", 0), + createRow("Constantinople", 9)) + SparkPlan.currentContext.set(sqlContext) + val input = + new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1)) + // Treat the existing sort operators as the source-of-truth for this test + val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect() + val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect() + val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect() + assert (defaultSorted === externalSorted) + assert (unsafeSorted === externalSorted) + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala new file mode 100644 index 0000000000000..093ce3504e2c2 --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join and + * unsafe external sort enabled. + */ +class UnsafeSortMergeCompatibiltySuite extends SortMergeCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.CODEGEN_ENABLED, "true") + TestHive.setConf(SQLConf.UNSAFE_ENABLED, "true") + TestHive.setConf(SQLConf.EXTERNAL_SORT, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.CODEGEN_ENABLED, "false") + TestHive.setConf(SQLConf.UNSAFE_ENABLED, "false") + TestHive.setConf(SQLConf.EXTERNAL_SORT, "false") + super.afterAll() + } +} \ No newline at end of file From dda6752873d162a183e3b216c1562f38b9e04fc7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 May 2015 14:04:03 -0700 Subject: [PATCH 04/62] Commit some missing code from an old git stash. --- .../unsafe/sort/UnsafeExternalSorter.java | 23 ++-- .../unsafe/sort/UnsafeInMemorySorter.java | 10 +- .../sql/catalyst/expressions/UnsafeRow.java | 31 +++++ .../apache/spark/sql/execution/Exchange.scala | 58 ++++++++- .../spark/sql/execution/SparkStrategies.scala | 7 +- .../spark/sql/execution/basicOperators.scala | 119 ++++++++++++++++++ .../SortMergeCompatibilitySuite.scala | 1 + 7 files changed, 231 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 96159ec1b1feb..b341f1b4a989a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -137,7 +137,7 @@ public void spill() throws IOException { openSorter(); } - private long freeMemory() { + public long freeMemory() { long memoryFreed = 0; final Iterator iter = allocatedPages.iterator(); while (iter.hasNext()) { @@ -223,13 +223,20 @@ public void insertRecord( } public UnsafeSorterIterator getSortedIterator() throws IOException { - final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator); - for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpill(spillWriter.getReader(blockManager)); + final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + if (!spillWriters.isEmpty()) { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + if (inMemoryIterator.hasNext()) { + spillMerger.addSpill(inMemoryIterator); + } + return spillMerger.getSortedIterator(); + } else { + return inMemoryIterator; } - spillWriters.clear(); - spillMerger.addSpill(sorter.getSortedIterator()); - return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 0ce7989463a23..5ddabb63e649d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,6 +19,7 @@ import java.util.Comparator; +import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; @@ -50,10 +51,10 @@ private static final class SortComparator implements Comparator + false + } + logInfo(s"For row with data types ${withShuffle.schema.map(_.dataType)}, " + + s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion") + if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { + logInfo("Using unsafe external sort!") + UnsafeExternalSort(rowOrdering, global = false, withShuffle) + } else if (sqlContext.conf.externalSortEnabled) { + logInfo("Not using unsafe sort") + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } } else { Sort(rowOrdering, global = false, withShuffle) } @@ -317,6 +352,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (meetsRequirements && compatible && !needsAnySort) { operator } else { + logInfo("Looking through Exchange") // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. @@ -334,7 +370,21 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (UnspecifiedDistribution, Seq(), child) => child case (UnspecifiedDistribution, rowOrdering, child) => - if (sqlContext.conf.externalSortEnabled) { + // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow + // supports the given schema. + val supportsUnsafeRowConversion: Boolean = try { + new UnsafeRowConverter(child.schema.map(_.dataType).toArray) + true + } catch { + case NonFatal(e) => + false + } + logInfo(s"For row with data types ${child.schema.map(_.dataType)}, " + + s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion") + if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { + logInfo("Using unsafe external sort!") + UnsafeExternalSort(rowOrdering, global = false, child) + } else if (sqlContext.conf.externalSortEnabled) { ExternalSort(rowOrdering, global = false, child) } else { Sort(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..a7f2b85576f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -74,7 +74,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they + * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { @@ -102,8 +102,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled => - val mergeJoin = + val mergeJoin = if (sqlContext.conf.unsafeEnabled) { + joins.UnsafeSortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + } else { joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + } condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 647c4ab5cb651..b714ae1e9a5cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,6 +17,12 @@ package org.apache.spark.sql.execution +import java.util.Arrays + +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.util.collection.unsafe.sort.{RecordComparator, PrefixComparator, UnsafeExternalSorter} +import org.apache.spark.{TaskContext, SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -245,6 +251,119 @@ case class ExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } +/** + * :: DeveloperApi :: + * TODO(josh): document + * Performs a sort, spilling to disk as needed. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ +@DeveloperApi +case class UnsafeExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + private[this] val numFields: Int = child.schema.size + private[this] val schema: StructType = child.schema + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + // TODO(josh): This code is unreadably messy; this should be split into a separate file + // and written in Java. + assert (codegenEnabled) + def doSort(iterator: Iterator[Row]): Iterator[Row] = { + val ordering = newOrdering(sortOrder, child.output) + val rowConverter = new UnsafeRowConverter(schema.map(_.dataType).toArray) + var rowConversionScratchSpace = new Array[Long](1024) + val prefixComparator = new PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + } + val recordComparator = new RecordComparator { + private[this] val row1 = new UnsafeRow + private[this] val row2 = new UnsafeRow + override def compare( + baseObj1: scala.Any, baseOff1: Long, baseObj2: scala.Any, baseOff2: Long): Int = { + row1.pointTo(baseObj1, baseOff1, numFields, schema) + row2.pointTo(baseObj2, baseOff2, numFields, schema) + ordering.compare(row1, row2) + } + } + val sorter = new UnsafeExternalSorter( + TaskContext.get.taskMemoryManager(), + SparkEnv.get.shuffleMemoryManager, + SparkEnv.get.blockManager, + TaskContext.get, + recordComparator, + prefixComparator, + 4096, + SparkEnv.get.conf + ) + while (iterator.hasNext) { + val row: Row = iterator.next() + val sizeRequirement = rowConverter.getSizeRequirement(row) + if (sizeRequirement / 8 > rowConversionScratchSpace.length) { + rowConversionScratchSpace = new Array[Long](sizeRequirement / 8) + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero + // out the portion of the buffer that we'll actually write to. + Arrays.fill(rowConversionScratchSpace, 0, sizeRequirement / 8, 0) + } + val bytesWritten = + rowConverter.writeRow(row, rowConversionScratchSpace, PlatformDependent.LONG_ARRAY_OFFSET) + assert (bytesWritten == sizeRequirement) + val prefix: Long = 0 // dummy prefix until we implement prefix calculation + sorter.insertRecord( + rowConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequirement, + prefix + ) + } + val sortedIterator = sorter.getSortedIterator + // TODO: need to avoid memory leaks on exceptions, etc. by wrapping in resource cleanup blocks + // TODO: need to clean up spill files after success or failure. + new Iterator[Row] { + private[this] val row = new UnsafeRow() + override def hasNext: Boolean = sortedIterator.hasNext + + override def next(): Row = { + sortedIterator.loadNext() + if (hasNext) { + row.pointTo( + sortedIterator.getBaseObject, sortedIterator.getBaseOffset, numFields, schema) + println("Returned row " + row) + row + } else { + val rowDataCopy = new Array[Byte](sortedIterator.getRecordLength) + PlatformDependent.copyMemory( + sortedIterator.getBaseObject, + sortedIterator.getBaseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sortedIterator.getRecordLength + ) + row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema) + sorter.freeMemory() + row + } + } + } + } + child.execute().mapPartitions(doSort, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + + /** * :: DeveloperApi :: * Return a new RDD that has exactly `numPartitions` partitions. diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index f458567e5d7ea..7129007cdfbc2 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.TestSQLContext._ /** * Runs the test cases that are included in the hive distribution with sort merge join is true. From c8792dee73a15bf9734148dbfa62ac64d6459b49 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 May 2015 14:11:57 -0700 Subject: [PATCH 05/62] Remove some debug logging --- .../org/apache/spark/sql/execution/Exchange.scala | 15 +++------------ .../spark/sql/execution/basicOperators.scala | 1 - 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 42a9111b7e70e..8372bd0810234 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import scala.util.control.NonFatal + +import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer @@ -32,8 +35,6 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} -import scala.util.control.NonFatal - object Exchange { /** * Returns true when the ordering expressions are a subset of the key. @@ -194,7 +195,6 @@ case class Exchange( } val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) if (newOrdering.nonEmpty) { - println("Shuffling with a key ordering") shuffled.setKeyOrdering(keyOrdering) } shuffled.setSerializer(serializer) @@ -308,7 +308,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ partitioning: Partitioning, rowOrdering: Seq[SortOrder], child: SparkPlan): SparkPlan = { - logInfo("In addOperatorsIfNecessary") val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering val needsShuffle = child.outputPartitioning != partitioning @@ -328,13 +327,9 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case NonFatal(e) => false } - logInfo(s"For row with data types ${withShuffle.schema.map(_.dataType)}, " + - s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion") if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { - logInfo("Using unsafe external sort!") UnsafeExternalSort(rowOrdering, global = false, withShuffle) } else if (sqlContext.conf.externalSortEnabled) { - logInfo("Not using unsafe sort") ExternalSort(rowOrdering, global = false, withShuffle) } else { Sort(rowOrdering, global = false, withShuffle) @@ -352,7 +347,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (meetsRequirements && compatible && !needsAnySort) { operator } else { - logInfo("Looking through Exchange") // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. In this case, we need to add Exchange operators. @@ -379,10 +373,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case NonFatal(e) => false } - logInfo(s"For row with data types ${child.schema.map(_.dataType)}, " + - s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion") if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { - logInfo("Using unsafe external sort!") UnsafeExternalSort(rowOrdering, global = false, child) } else if (sqlContext.conf.externalSortEnabled) { ExternalSort(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b714ae1e9a5cb..3cdaac4ea5c18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -337,7 +337,6 @@ case class UnsafeExternalSort( if (hasNext) { row.pointTo( sortedIterator.getBaseObject, sortedIterator.getBaseOffset, numFields, schema) - println("Returned row " + row) row } else { val rowDataCopy = new Array[Byte](sortedIterator.getRecordLength) From 9cc98f5f3e746ade172b4dad3b4ba7401426b5c6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 27 May 2015 16:03:15 -0700 Subject: [PATCH 06/62] Move more code to Java; fix bugs in UnsafeRowConverter length type. The length type is an int, not long, but the code was inconsistent about this. I also now use byte arrays instead of long arrays in some places in order to avoid off-by-factor-of-8 errors. --- .../unsafe/sort/UnsafeExternalSorter.java | 8 +- .../sql/catalyst/expressions/UnsafeRow.java | 7 + .../execution/UnsafeExternalRowSorter.java | 171 ++++++++++++++++++ .../spark/sql/execution/basicOperators.scala | 84 +-------- 4 files changed, 186 insertions(+), 84 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index b341f1b4a989a..d0ede69e9e66c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -151,7 +151,7 @@ public long freeMemory() { return memoryFreed; } - private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + private void ensureSpaceInDataPage(int requiredSpace) throws IOException { // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. @@ -176,7 +176,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { } if (requiredSpace > PAGE_SIZE) { // TODO: throw a more specific exception? - throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); } else if (requiredSpace > spaceInCurrentPage) { if (spillingEnabled) { @@ -187,7 +187,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); if (memoryAcquiredAfterSpill != PAGE_SIZE) { shuffleMemoryManager.release(memoryAcquiredAfterSpill); - throw new Exception("Can't allocate memory!"); + throw new IOException("Can't allocate memory!"); } } } @@ -202,7 +202,7 @@ public void insertRecord( Object recordBaseObject, long recordBaseOffset, int lengthInBytes, - long prefix) throws Exception { + long prefix) throws IOException { // Need 4 bytes to store the record length. ensureSpaceInDataPage(lengthInBytes + 4); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 5a974d7faec55..94c038c74eb3a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,6 +17,11 @@ package org.apache.spark.sql.catalyst.expressions; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; +import javax.annotation.Nullable; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; @@ -55,6 +60,8 @@ */ public final class UnsafeRow extends MutableRow { + /** Hack for if we want to pass around an UnsafeRow which also carries around its backing data */ + @Nullable public byte[] backingArray; private Object baseObject; private long baseOffset; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java new file mode 100644 index 0000000000000..137e473e34ac2 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.Arrays; + +import scala.Function1; +import scala.collection.AbstractIterator; +import scala.collection.Iterator; +import scala.math.Ordering; + +import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; + +final class UnsafeExternalRowSorter { + + private final StructType schema; + private final UnsafeRowConverter rowConverter; + private final RowComparator rowComparator; + private final PrefixComparator prefixComparator; + private final Function1 prefixComputer; + + public UnsafeExternalRowSorter( + StructType schema, + Ordering ordering, + PrefixComparator prefixComparator, + // TODO: if possible, avoid this boxing of the return value + Function1 prefixComputer) { + this.schema = schema; + this.rowConverter = new UnsafeRowConverter(schema); + this.rowComparator = new RowComparator(ordering, schema); + this.prefixComparator = prefixComparator; + this.prefixComputer = prefixComputer; + } + + public Iterator sort(Iterator inputIterator) throws IOException { + final SparkEnv sparkEnv = SparkEnv.get(); + final TaskContext taskContext = TaskContext.get(); + byte[] rowConversionBuffer = new byte[1024 * 8]; + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + taskContext.taskMemoryManager(), + sparkEnv.shuffleMemoryManager(), + sparkEnv.blockManager(), + taskContext, + rowComparator, + prefixComparator, + 4096, + sparkEnv.conf() + ); + try { + while (inputIterator.hasNext()) { + final Row row = inputIterator.next(); + final int sizeRequirement = rowConverter.getSizeRequirement(row); + if (sizeRequirement > rowConversionBuffer.length) { + rowConversionBuffer = new byte[sizeRequirement]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero + // out the portion of the buffer that we'll actually write to. + Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0); + } + final int bytesWritten = + rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET); + assert (bytesWritten == sizeRequirement); + final long prefix = prefixComputer.apply(row); + sorter.insertRecord( + rowConversionBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeRequirement, + prefix + ); + } + final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + return new AbstractIterator() { + + private final int numFields = schema.length(); + private final UnsafeRow row = new UnsafeRow(); + + @Override + public boolean hasNext() { + return sortedIterator.hasNext(); + } + + @Override + public Row next() { + try { + sortedIterator.loadNext(); + if (hasNext()) { + row.pointTo( + sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema); + return row; + } else { + final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()]; + PlatformDependent.copyMemory( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sortedIterator.getRecordLength() + ); + row.backingArray = rowDataCopy; + row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema); + sorter.freeMemory(); + return row; + } + } catch (IOException e) { + // TODO: we need to ensure that files are cleaned properly after an exception, + // so we need better cleanup methods than freeMemory(). + sorter.freeMemory(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + PlatformDependent.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + }; + }; + } catch (IOException e) { + // TODO: we need to ensure that files are cleaned properly after an exception, + // so we need better cleanup methods than freeMemory(). + sorter.freeMemory(); + throw e; + } + } + + private static final class RowComparator extends RecordComparator { + private final StructType schema; + private final Ordering ordering; + private final int numFields; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + + public RowComparator(Ordering ordering, StructType schema) { + this.schema = schema; + this.numFields = schema.length(); + this.ordering = ordering; + } + + @Override + public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + row1.pointTo(baseObj1, baseOff1, numFields, schema); + row2.pointTo(baseObj2, baseOff2, numFields, schema); + return ordering.compare(row1, row2); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3cdaac4ea5c18..22a5688c1e8b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,12 +17,9 @@ package org.apache.spark.sql.execution -import java.util.Arrays - import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.PlatformDependent -import org.apache.spark.util.collection.unsafe.sort.{RecordComparator, PrefixComparator, UnsafeExternalSorter} -import org.apache.spark.{TaskContext, SparkEnv, HashPartitioner, SparkConf} +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import org.apache.spark.{SparkEnv, HashPartitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -272,87 +269,14 @@ case class UnsafeExternalSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { - // TODO(josh): This code is unreadably messy; this should be split into a separate file - // and written in Java. assert (codegenEnabled) def doSort(iterator: Iterator[Row]): Iterator[Row] = { val ordering = newOrdering(sortOrder, child.output) - val rowConverter = new UnsafeRowConverter(schema.map(_.dataType).toArray) - var rowConversionScratchSpace = new Array[Long](1024) val prefixComparator = new PrefixComparator { override def compare(prefix1: Long, prefix2: Long): Int = 0 } - val recordComparator = new RecordComparator { - private[this] val row1 = new UnsafeRow - private[this] val row2 = new UnsafeRow - override def compare( - baseObj1: scala.Any, baseOff1: Long, baseObj2: scala.Any, baseOff2: Long): Int = { - row1.pointTo(baseObj1, baseOff1, numFields, schema) - row2.pointTo(baseObj2, baseOff2, numFields, schema) - ordering.compare(row1, row2) - } - } - val sorter = new UnsafeExternalSorter( - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, - SparkEnv.get.blockManager, - TaskContext.get, - recordComparator, - prefixComparator, - 4096, - SparkEnv.get.conf - ) - while (iterator.hasNext) { - val row: Row = iterator.next() - val sizeRequirement = rowConverter.getSizeRequirement(row) - if (sizeRequirement / 8 > rowConversionScratchSpace.length) { - rowConversionScratchSpace = new Array[Long](sizeRequirement / 8) - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero - // out the portion of the buffer that we'll actually write to. - Arrays.fill(rowConversionScratchSpace, 0, sizeRequirement / 8, 0) - } - val bytesWritten = - rowConverter.writeRow(row, rowConversionScratchSpace, PlatformDependent.LONG_ARRAY_OFFSET) - assert (bytesWritten == sizeRequirement) - val prefix: Long = 0 // dummy prefix until we implement prefix calculation - sorter.insertRecord( - rowConversionScratchSpace, - PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequirement, - prefix - ) - } - val sortedIterator = sorter.getSortedIterator - // TODO: need to avoid memory leaks on exceptions, etc. by wrapping in resource cleanup blocks - // TODO: need to clean up spill files after success or failure. - new Iterator[Row] { - private[this] val row = new UnsafeRow() - override def hasNext: Boolean = sortedIterator.hasNext - - override def next(): Row = { - sortedIterator.loadNext() - if (hasNext) { - row.pointTo( - sortedIterator.getBaseObject, sortedIterator.getBaseOffset, numFields, schema) - row - } else { - val rowDataCopy = new Array[Byte](sortedIterator.getRecordLength) - PlatformDependent.copyMemory( - sortedIterator.getBaseObject, - sortedIterator.getBaseOffset, - rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sortedIterator.getRecordLength - ) - row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema) - sorter.freeMemory() - row - } - } - } + def prefixComputer(row: Row): Long = 0 + new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) } child.execute().mapPartitions(doSort, preservesPartitioning = true) } From 73cc761247283ed67e5f80b9f32fc9dd68f3f9d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 5 Jun 2015 15:58:24 -0700 Subject: [PATCH 07/62] Fix whitespace --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a7f2b85576f54..9f10bf05b6bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -74,7 +74,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be * ''broadcasted'' to all of the executors involved in the join, as a - * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they + * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { From dfdb93f524afcbc8c00895b827fc9f5ed890bcbc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 5 Jun 2015 16:06:12 -0700 Subject: [PATCH 08/62] SparkFunSuite change --- .../spark/sql/execution/UnsafeExternalSortSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index f5a6368a2b16a..d418638c30eb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ -class UnsafeExternalSortSuite extends FunSuite with Matchers { +class UnsafeExternalSortSuite extends SparkFunSuite with Matchers { private def createRow(values: Any*): Row = { new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) From b420a7100096bc1a19931f42878b41403e213756 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 11 Jun 2015 15:29:38 -0700 Subject: [PATCH 09/62] Move most of the existing SMJ code into Java. --- .../execution/UnsafeExternalRowSorter.java | 4 +- .../joins/SortMergeJoinIterator.java | 153 ++++++++++++++++++ .../spark/sql/AbstractScalaRowIterator.scala | 25 +++ .../sql/execution/joins/SortMergeJoin.scala | 108 +------------ .../execution/joins/UnsafeSortMergeJoin.scala | 113 +------------ 5 files changed, 193 insertions(+), 210 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 137e473e34ac2..2380c4614b5e3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -21,12 +21,12 @@ import java.util.Arrays; import scala.Function1; -import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; @@ -97,7 +97,7 @@ public Iterator sort(Iterator inputIterator) throws IOException { ); } final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); - return new AbstractIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private final UnsafeRow row = new UnsafeRow(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java new file mode 100644 index 0000000000000..b046cd5980e0c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins; + +import java.util.NoSuchElementException; +import javax.annotation.Nullable; + +import scala.Function1; +import scala.collection.Iterator; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import org.apache.spark.sql.AbstractScalaRowIterator; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.JoinedRow5; +import org.apache.spark.sql.catalyst.expressions.RowOrdering; +import org.apache.spark.util.collection.CompactBuffer; + +/** + * Implements the merge step of sort-merge join. + */ +class SortMergeJoinIterator extends AbstractScalaRowIterator { + + private static final ClassTag ROW_CLASS_TAG = ClassTag$.MODULE$.apply(Row.class); + private final Iterator leftIter; + private final Iterator rightIter; + private final Function1 leftKeyGenerator; + private final Function1 rightKeyGenerator; + private final RowOrdering keyOrdering; + private final JoinedRow5 joinRow = new JoinedRow5(); + + @Nullable private Row leftElement; + @Nullable private Row rightElement; + private Row leftKey; + private Row rightKey; + @Nullable private CompactBuffer rightMatches; + private int rightPosition = -1; + private boolean stop = false; + private Row matchKey; + + public SortMergeJoinIterator( + Iterator leftIter, + Iterator rightIter, + Function1 leftKeyGenerator, + Function1 rightKeyGenerator, + RowOrdering keyOrdering) { + this.leftIter = leftIter; + this.rightIter = rightIter; + this.leftKeyGenerator = leftKeyGenerator; + this.rightKeyGenerator = rightKeyGenerator; + this.keyOrdering = keyOrdering; + fetchLeft(); + fetchRight(); + } + + private void fetchLeft() { + if (leftIter.hasNext()) { + leftElement = leftIter.next(); + leftKey = leftKeyGenerator.apply(leftElement); + } else { + leftElement = null; + } + } + + private void fetchRight() { + if (rightIter.hasNext()) { + rightElement = rightIter.next(); + rightKey = rightKeyGenerator.apply(rightElement); + } else { + rightElement = null; + } + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. + */ + private boolean nextMatchingPair() { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + final int comparing = keyOrdering.compare(leftKey, rightKey); + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull(); + if (comparing > 0 || rightKey.anyNull()) { + fetchRight(); + } else if (comparing < 0 || leftKey.anyNull()) { + fetchLeft(); + } + } + rightMatches = new CompactBuffer(ROW_CLASS_TAG); + if (stop) { + stop = false; + // Iterate the right side to buffer all rows that match. + // As the records should be ordered, exit when we meet the first record that not match. + while (!stop && rightElement != null) { + rightMatches.$plus$eq(rightElement); + fetchRight(); + stop = keyOrdering.compare(leftKey, rightKey) != 0; + } + if (rightMatches.size() > 0) { + rightPosition = 0; + matchKey = leftKey; + } + } + } + return rightMatches != null && rightMatches.size() > 0; + } + + @Override + public boolean hasNext() { + return nextMatchingPair(); + } + + @Override + public Row next() { + if (hasNext()) { + // We are using the buffered right rows and run down left iterator + final Row joinedRow = joinRow.apply(leftElement, rightMatches.apply(rightPosition)); + rightPosition += 1; + if (rightPosition >= rightMatches.size()) { + rightPosition = 0; + fetchLeft(); + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false; + rightMatches = null; + } + } + return joinedRow; + } else { + // No more results + throw new NoSuchElementException(); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala new file mode 100644 index 0000000000000..38d0b6ad25c9a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator + * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to + * `Row` in order to work around a spurious IntelliJ compiler error. + */ +private[spark] abstract class AbstractScalaRowIterator extends Iterator[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 2abe65a71813d..ea4ae490f4db5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: @@ -63,105 +60,12 @@ case class SortMergeJoin( val rightResults = right.execute().map(_.copy()) leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[InternalRow] { - // Mutable per row objects. - private[this] val joinRow = new JoinedRow5 - private[this] var leftElement: InternalRow = _ - private[this] var rightElement: InternalRow = _ - private[this] var leftKey: InternalRow = _ - private[this] var rightKey: InternalRow = _ - private[this] var rightMatches: CompactBuffer[InternalRow] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: InternalRow = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): InternalRow = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow - } else { - // no more result - throw new NoSuchElementException - } - } - - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null - } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - rightKey = rightKeyGenerator(rightElement) - } else { - rightElement = null - } - } - - private def initialize() = { - fetchLeft() - fetchRight() - } - - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[InternalRow]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } - } - rightMatches != null && rightMatches.size > 0 - } - } + new SortMergeJoinIterator( + leftIter, + rightIter, + leftKeyGenerator, + rightKeyGenerator, + keyOrdering); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala index 73a9bc6a2cf60..2b2c6e6dbc500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala @@ -17,21 +17,17 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan} -import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: - * Performs an sort merge join of two child relations. - * TODO(josh): Document + * Optimized version of [[SortMergeJoin]], implemented as part of Project Tungsten. */ @DeveloperApi case class UnsafeSortMergeJoin( @@ -84,107 +80,12 @@ case class UnsafeSortMergeJoin( } leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[Row] { - // Mutable per row objects. - private[this] val joinRow = new JoinedRow5 - private[this] var leftElement: Row = _ - private[this] var rightElement: Row = _ - private[this] var leftKey: Row = _ - private[this] var rightKey: Row = _ - private[this] var rightMatches: CompactBuffer[Row] = _ - private[this] var rightPosition: Int = -1 - private[this] var stop: Boolean = false - private[this] var matchKey: Row = _ - - // initialize iterator - initialize() - - override final def hasNext: Boolean = nextMatchingPair() - - override final def next(): Row = { - if (hasNext) { - // we are using the buffered right rows and run down left iterator - val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) - rightPosition += 1 - if (rightPosition >= rightMatches.size) { - rightPosition = 0 - fetchLeft() - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false - rightMatches = null - } - } - joinedRow - } else { - // no more result - throw new NoSuchElementException - } - } - - private def fetchLeft() = { - if (leftIter.hasNext) { - leftElement = leftIter.next() - println(leftElement) - leftKey = leftKeyGenerator(leftElement) - } else { - leftElement = null - } - } - - private def fetchRight() = { - if (rightIter.hasNext) { - rightElement = rightIter.next() - println(right) - rightKey = rightKeyGenerator(rightElement) - } else { - rightElement = null - } - } - - private def initialize() = { - fetchLeft() - fetchRight() - } - - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private def nextMatchingPair(): Boolean = { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - val comparing = keyOrdering.compare(leftKey, rightKey) - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull - if (comparing > 0 || rightKey.anyNull) { - fetchRight() - } else if (comparing < 0 || leftKey.anyNull) { - fetchLeft() - } - } - rightMatches = new CompactBuffer[Row]() - if (stop) { - stop = false - // iterate the right side to buffer all rows that matches - // as the records should be ordered, exit when we meet the first that not match - while (!stop && rightElement != null) { - rightMatches += rightElement - fetchRight() - stop = keyOrdering.compare(leftKey, rightKey) != 0 - } - if (rightMatches.size > 0) { - rightPosition = 0 - matchKey = leftKey - } - } - } - rightMatches != null && rightMatches.size > 0 - } - } + new SortMergeJoinIterator( + leftIter, + rightIter, + leftKeyGenerator, + rightKeyGenerator, + keyOrdering); } } } From 1b841ca95836c96bf10a04e16b145d5bf65260d6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 11 Jun 2015 16:22:06 -0700 Subject: [PATCH 10/62] WIP towards copying --- .../sql/execution/joins/SortMergeJoinIterator.java | 2 +- .../apache/spark/sql/execution/basicOperators.scala | 7 ++++--- .../sql/execution/joins/UnsafeSortMergeJoin.scala | 10 +++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java index b046cd5980e0c..888c74d72b002 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java @@ -112,7 +112,7 @@ private boolean nextMatchingPair() { // Iterate the right side to buffer all rows that match. // As the records should be ordered, exit when we meet the first record that not match. while (!stop && rightElement != null) { - rightMatches.$plus$eq(rightElement); + rightMatches.$plus$eq(rightElement.copy()); fetchRight(); stop = keyOrdering.compare(leftKey, rightKey) != 0; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 22a5688c1e8b7..2ae683d601ede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -250,8 +250,9 @@ case class ExternalSort( /** * :: DeveloperApi :: - * TODO(josh): document - * Performs a sort, spilling to disk as needed. + * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Project Tungsten). + * * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ @@ -262,7 +263,6 @@ case class UnsafeExternalSort( child: SparkPlan) extends UnaryNode { - private[this] val numFields: Int = child.schema.size private[this] val schema: StructType = child.schema override def requiredChildDistribution: Seq[Distribution] = @@ -275,6 +275,7 @@ case class UnsafeExternalSort( val prefixComparator = new PrefixComparator { override def compare(prefix1: Long, prefix2: Long): Int = 0 } + // TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation. def prefixComputer(row: Row): Long = 0 new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala index 2b2c6e6dbc500..d836c4b4faf2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.execution.joins - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan} @@ -63,17 +63,17 @@ case class UnsafeSortMergeJoin( // Only sort if necessary: val leftOrder = requiredOrders(leftKeys) - val leftResults = { + val leftResults: RDD[Row] = { if (left.outputOrdering == leftOrder) { - left.execute().map(_.copy()) + left.execute() } else { new UnsafeExternalSort(leftOrder, global = false, left).execute() } } val rightOrder = requiredOrders(rightKeys) - val rightResults = { + val rightResults: RDD[Row] = { if (right.outputOrdering == rightOrder) { - right.execute().map(_.copy()) + right.execute() } else { new UnsafeExternalSort(rightOrder, global = false, right).execute() } From 269cf864907fa8c9a76f89b08477dc926a94e548 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 12 Jun 2015 19:11:42 -0700 Subject: [PATCH 11/62] Back out SMJ operator change; isolate changes to selection of sort op. I'll consider fusing the sort and merge steps in a followup PR. --- .../joins/SortMergeJoinIterator.java | 153 ------------------ .../spark/sql/execution/SparkStrategies.scala | 5 +- .../sql/execution/joins/SortMergeJoin.scala | 109 ++++++++++++- .../execution/joins/UnsafeSortMergeJoin.scala | 91 ----------- 4 files changed, 104 insertions(+), 254 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java deleted file mode 100644 index 888c74d72b002..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/joins/SortMergeJoinIterator.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins; - -import java.util.NoSuchElementException; -import javax.annotation.Nullable; - -import scala.Function1; -import scala.collection.Iterator; -import scala.reflect.ClassTag; -import scala.reflect.ClassTag$; - -import org.apache.spark.sql.AbstractScalaRowIterator; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.JoinedRow5; -import org.apache.spark.sql.catalyst.expressions.RowOrdering; -import org.apache.spark.util.collection.CompactBuffer; - -/** - * Implements the merge step of sort-merge join. - */ -class SortMergeJoinIterator extends AbstractScalaRowIterator { - - private static final ClassTag ROW_CLASS_TAG = ClassTag$.MODULE$.apply(Row.class); - private final Iterator leftIter; - private final Iterator rightIter; - private final Function1 leftKeyGenerator; - private final Function1 rightKeyGenerator; - private final RowOrdering keyOrdering; - private final JoinedRow5 joinRow = new JoinedRow5(); - - @Nullable private Row leftElement; - @Nullable private Row rightElement; - private Row leftKey; - private Row rightKey; - @Nullable private CompactBuffer rightMatches; - private int rightPosition = -1; - private boolean stop = false; - private Row matchKey; - - public SortMergeJoinIterator( - Iterator leftIter, - Iterator rightIter, - Function1 leftKeyGenerator, - Function1 rightKeyGenerator, - RowOrdering keyOrdering) { - this.leftIter = leftIter; - this.rightIter = rightIter; - this.leftKeyGenerator = leftKeyGenerator; - this.rightKeyGenerator = rightKeyGenerator; - this.keyOrdering = keyOrdering; - fetchLeft(); - fetchRight(); - } - - private void fetchLeft() { - if (leftIter.hasNext()) { - leftElement = leftIter.next(); - leftKey = leftKeyGenerator.apply(leftElement); - } else { - leftElement = null; - } - } - - private void fetchRight() { - if (rightIter.hasNext()) { - rightElement = rightIter.next(); - rightKey = rightKeyGenerator.apply(rightElement); - } else { - rightElement = null; - } - } - - /** - * Searches the right iterator for the next rows that have matches in left side, and store - * them in a buffer. - * - * @return true if the search is successful, and false if the right iterator runs out of - * tuples. - */ - private boolean nextMatchingPair() { - if (!stop && rightElement != null) { - // run both side to get the first match pair - while (!stop && leftElement != null && rightElement != null) { - final int comparing = keyOrdering.compare(leftKey, rightKey); - // for inner join, we need to filter those null keys - stop = comparing == 0 && !leftKey.anyNull(); - if (comparing > 0 || rightKey.anyNull()) { - fetchRight(); - } else if (comparing < 0 || leftKey.anyNull()) { - fetchLeft(); - } - } - rightMatches = new CompactBuffer(ROW_CLASS_TAG); - if (stop) { - stop = false; - // Iterate the right side to buffer all rows that match. - // As the records should be ordered, exit when we meet the first record that not match. - while (!stop && rightElement != null) { - rightMatches.$plus$eq(rightElement.copy()); - fetchRight(); - stop = keyOrdering.compare(leftKey, rightKey) != 0; - } - if (rightMatches.size() > 0) { - rightPosition = 0; - matchKey = leftKey; - } - } - } - return rightMatches != null && rightMatches.size() > 0; - } - - @Override - public boolean hasNext() { - return nextMatchingPair(); - } - - @Override - public Row next() { - if (hasNext()) { - // We are using the buffered right rows and run down left iterator - final Row joinedRow = joinRow.apply(leftElement, rightMatches.apply(rightPosition)); - rightPosition += 1; - if (rightPosition >= rightMatches.size()) { - rightPosition = 0; - fetchLeft(); - if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { - stop = false; - rightMatches = null; - } - } - return joinedRow; - } else { - // No more results - throw new NoSuchElementException(); - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9f10bf05b6bdb..5daf86d817586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -102,11 +102,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if sqlContext.conf.sortMergeJoinEnabled => - val mergeJoin = if (sqlContext.conf.unsafeEnabled) { - joins.UnsafeSortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - } else { + val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) - } condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index ea4ae490f4db5..0699650102a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution.joins +import java.util.NoSuchElementException + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer /** * :: DeveloperApi :: @@ -60,12 +64,105 @@ case class SortMergeJoin( val rightResults = right.execute().map(_.copy()) leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new SortMergeJoinIterator( - leftIter, - rightIter, - leftKeyGenerator, - rightKeyGenerator, - keyOrdering); + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. + */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala deleted file mode 100644 index d836c4b4faf2f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan} - -/** - * :: DeveloperApi :: - * Optimized version of [[SortMergeJoin]], implemented as part of Project Tungsten. - */ -@DeveloperApi -case class UnsafeSortMergeJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode { - - override def output: Seq[Attribute] = left.output ++ right.output - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - // this is to manually construct an ordering that can be used to compare keys from both sides - private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) - - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) - - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = - keys.map(SortOrder(_, Ascending)) - - protected override def doExecute(): RDD[Row] = { - // Note that we purposely do not require out input to be sorted. Instead, we'll sort it - // ourselves using UnsafeExternalSorter. Not requiring the input to be sorted will prevent the - // Exchange from pushing the sort into the shuffle, which will allow the shuffle to benefit from - // Project Tungsten's shuffle optimizations which currently cannot be applied to shuffles that - // specify a key ordering. - - // Only sort if necessary: - val leftOrder = requiredOrders(leftKeys) - val leftResults: RDD[Row] = { - if (left.outputOrdering == leftOrder) { - left.execute() - } else { - new UnsafeExternalSort(leftOrder, global = false, left).execute() - } - } - val rightOrder = requiredOrders(rightKeys) - val rightResults: RDD[Row] = { - if (right.outputOrdering == rightOrder) { - right.execute() - } else { - new UnsafeExternalSort(rightOrder, global = false, right).execute() - } - } - - leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new SortMergeJoinIterator( - leftIter, - rightIter, - leftKeyGenerator, - rightKeyGenerator, - keyOrdering); - } - } -} From d468a889495b5f74285aa7cec63ef82b4888cc8f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 13 Jun 2015 18:00:35 -0700 Subject: [PATCH 12/62] Update for InternalRow refactoring --- .../execution/UnsafeExternalRowSorter.java | 18 +++--- .../spark/sql/AbstractScalaRowIterator.scala | 4 +- .../apache/spark/sql/execution/Exchange.scala | 58 +++++-------------- .../spark/sql/execution/basicOperators.scala | 6 +- .../sql/execution/joins/SortMergeJoin.scala | 19 +++--- 5 files changed, 39 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 2380c4614b5e3..7c53ea7bdac24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -27,7 +27,7 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; import org.apache.spark.sql.types.StructType; @@ -43,14 +43,14 @@ final class UnsafeExternalRowSorter { private final UnsafeRowConverter rowConverter; private final RowComparator rowComparator; private final PrefixComparator prefixComparator; - private final Function1 prefixComputer; + private final Function1 prefixComputer; public UnsafeExternalRowSorter( StructType schema, - Ordering ordering, + Ordering ordering, PrefixComparator prefixComparator, // TODO: if possible, avoid this boxing of the return value - Function1 prefixComputer) { + Function1 prefixComputer) { this.schema = schema; this.rowConverter = new UnsafeRowConverter(schema); this.rowComparator = new RowComparator(ordering, schema); @@ -58,7 +58,7 @@ public UnsafeExternalRowSorter( this.prefixComputer = prefixComputer; } - public Iterator sort(Iterator inputIterator) throws IOException { + public Iterator sort(Iterator inputIterator) throws IOException { final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); byte[] rowConversionBuffer = new byte[1024 * 8]; @@ -74,7 +74,7 @@ public Iterator sort(Iterator inputIterator) throws IOException { ); try { while (inputIterator.hasNext()) { - final Row row = inputIterator.next(); + final InternalRow row = inputIterator.next(); final int sizeRequirement = rowConverter.getSizeRequirement(row); if (sizeRequirement > rowConversionBuffer.length) { rowConversionBuffer = new byte[sizeRequirement]; @@ -108,7 +108,7 @@ public boolean hasNext() { } @Override - public Row next() { + public InternalRow next() { try { sortedIterator.loadNext(); if (hasNext()) { @@ -150,12 +150,12 @@ public Row next() { private static final class RowComparator extends RecordComparator { private final StructType schema; - private final Ordering ordering; + private final Ordering ordering; private final int numFields; private final UnsafeRow row1 = new UnsafeRow(); private final UnsafeRow row2 = new UnsafeRow(); - public RowComparator(Ordering ordering, StructType schema) { + public RowComparator(Ordering ordering, StructType schema) { this.schema = schema; this.numFields = schema.length(); this.ordering = ordering; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala index 38d0b6ad25c9a..cfefb13e7721e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.InternalRow + /** * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to * `Row` in order to work around a spurious IntelliJ compiler error. */ -private[spark] abstract class AbstractScalaRowIterator extends Iterator[Row] +private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 8372bd0810234..6d0c97e5e23dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import scala.util.control.NonFatal -import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer @@ -35,16 +34,6 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} -object Exchange { - /** - * Returns true when the ordering expressions are a subset of the key. - * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. - */ - def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { - desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) - } -} - /** * :: DeveloperApi :: * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each @@ -194,9 +183,6 @@ case class Exchange( } } val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part) - if (newOrdering.nonEmpty) { - shuffled.setKeyOrdering(keyOrdering) - } shuffled.setSerializer(serializer) shuffled.map(_._2) @@ -317,23 +303,20 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child } - val withSort = if (needSort) { - // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow - // supports the given schema. - val supportsUnsafeRowConversion: Boolean = try { - new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray) - true - } catch { - case NonFatal(e) => - false - } - if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { - UnsafeExternalSort(rowOrdering, global = false, withShuffle) - } else if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) - } else { - Sort(rowOrdering, global = false, withShuffle) - } + val withSort = if (needSort) { + // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow + // supports the given schema. + val supportsUnsafeRowConversion: Boolean = try { + new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray) + true + } catch { + case NonFatal(e) => + false + } + if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { + UnsafeExternalSort(rowOrdering, global = false, withShuffle) + } else if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) } else { Sort(rowOrdering, global = false, withShuffle) } @@ -364,18 +347,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (UnspecifiedDistribution, Seq(), child) => child case (UnspecifiedDistribution, rowOrdering, child) => - // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow - // supports the given schema. - val supportsUnsafeRowConversion: Boolean = try { - new UnsafeRowConverter(child.schema.map(_.dataType).toArray) - true - } catch { - case NonFatal(e) => - false - } - if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { - UnsafeExternalSort(rowOrdering, global = false, child) - } else if (sqlContext.conf.externalSortEnabled) { + if (sqlContext.conf.externalSortEnabled) { ExternalSort(rowOrdering, global = false, child) } else { Sort(rowOrdering, global = false, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2ae683d601ede..262e14d7a859e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -268,15 +268,15 @@ case class UnsafeExternalSort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[Row] = attachTree(this, "sort") { + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { assert (codegenEnabled) - def doSort(iterator: Iterator[Row]): Iterator[Row] = { + def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val prefixComparator = new PrefixComparator { override def compare(prefix1: Long, prefix2: Long): Int = 0 } // TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation. - def prefixComputer(row: Row): Long = 0 + def prefixComputer(row: InternalRow): Long = 0 new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) } child.execute().mapPartitions(doSort, preservesPartitioning = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 0699650102a6f..2abe65a71813d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -22,7 +22,6 @@ import java.util.NoSuchElementException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.collection.CompactBuffer @@ -64,24 +63,24 @@ case class SortMergeJoin( val rightResults = right.execute().map(_.copy()) leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => - new Iterator[Row] { + new Iterator[InternalRow] { // Mutable per row objects. private[this] val joinRow = new JoinedRow5 - private[this] var leftElement: Row = _ - private[this] var rightElement: Row = _ - private[this] var leftKey: Row = _ - private[this] var rightKey: Row = _ - private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var leftElement: InternalRow = _ + private[this] var rightElement: InternalRow = _ + private[this] var leftKey: InternalRow = _ + private[this] var rightKey: InternalRow = _ + private[this] var rightMatches: CompactBuffer[InternalRow] = _ private[this] var rightPosition: Int = -1 private[this] var stop: Boolean = false - private[this] var matchKey: Row = _ + private[this] var matchKey: InternalRow = _ // initialize iterator initialize() override final def hasNext: Boolean = nextMatchingPair() - override final def next(): Row = { + override final def next(): InternalRow = { if (hasNext) { // we are using the buffered right rows and run down left iterator val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) @@ -144,7 +143,7 @@ case class SortMergeJoin( fetchLeft() } } - rightMatches = new CompactBuffer[Row]() + rightMatches = new CompactBuffer[InternalRow]() if (stop) { stop = false // iterate the right side to buffer all rows that matches From 7eafecf5634e8e2cc64cee01af8fa94c294f6e7c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 11:36:20 -0700 Subject: [PATCH 13/62] Port test to SparkPlanTest --- .../spark/sql/execution/basicOperators.scala | 2 +- .../spark/sql/UnsafeSortMergeJoinSuite.scala | 52 ------------------ .../execution/UnsafeExternalSortSuite.scala | 55 ++++++++++--------- 3 files changed, 29 insertions(+), 80 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 262e14d7a859e..bdc3d0037ce14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -269,7 +269,7 @@ case class UnsafeExternalSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert (codegenEnabled) + assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val prefixComparator = new PrefixComparator { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala deleted file mode 100644 index 4ce0537f02418..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.execution.UnsafeExternalSort -import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.scalatest.BeforeAndAfterEach - -class UnsafeSortMergeJoinSuite extends QueryTest with BeforeAndAfterEach { - // Ensures tables are loaded. - TestData - - conf.setConf(SQLConf.SORTMERGE_JOIN, "true") - conf.setConf(SQLConf.CODEGEN_ENABLED, "true") - conf.setConf(SQLConf.UNSAFE_ENABLED, "true") - conf.setConf(SQLConf.EXTERNAL_SORT, "true") - conf.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1") - - test("basic sort merge join test") { - val df = upperCaseData.join(lowerCaseData, $"n" === $"N") - print(df.queryExecution.optimizedPlan) - assert(df.queryExecution.sparkPlan.collect { - case smj: UnsafeSortMergeJoin => smj - }.nonEmpty) - checkAnswer( - df, - Seq( - Row(1, "A", 1, "a"), - Row(2, "B", 2, "b"), - Row(3, "C", 3, "c"), - Row(4, "D", 4, "d") - )) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index d418638c30eb5..98145d25e32db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -17,47 +17,48 @@ package org.apache.spark.sql.execution -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{SQLConf, SQLContext, Row} +import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.TestSQLContext -class UnsafeExternalSortSuite extends SparkFunSuite with Matchers { +import scala.util.Random + +class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } private def createRow(values: Any*): Row = { new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) } test("basic sorting") { - val sc = TestSQLContext.sparkContext - val sqlContext = new SQLContext(sc) - sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true") - val schema: StructType = StructType( - StructField("word", StringType, nullable = false) :: - StructField("number", IntegerType, nullable = false) :: Nil) + val inputData = Seq( + ("Hello", 9), + ("World", 4), + ("Hello", 7), + ("Skinny", 0), + ("Constantinople", 9) + ) + val sortOrder: Seq[SortOrder] = Seq( SortOrder(BoundReference(0, StringType, nullable = false), Ascending), SortOrder(BoundReference(1, IntegerType, nullable = false), Descending)) - val rowsToSort: Seq[Row] = Seq( - createRow("Hello", 9), - createRow("World", 4), - createRow("Hello", 7), - createRow("Skinny", 0), - createRow("Constantinople", 9)) - SparkPlan.currentContext.set(sqlContext) - val input = - new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1)) - // Treat the existing sort operators as the source-of-truth for this test - val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect() - val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect() - val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect() - assert (defaultSorted === externalSorted) - assert (unsafeSorted === externalSorted) + + checkAnswer( + Random.shuffle(inputData), + (input: SparkPlan) => new UnsafeExternalSort(sortOrder, global = false, input), + inputData + ) } } From 21d7d93c69a357aed4f11a2e7f8b04ba2f0d910c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 11:38:57 -0700 Subject: [PATCH 14/62] Back out of BlockObjectWriter change --- .../unsafe/sort/UnsafeSorterSpillWriter.java | 66 ++++++++++++++----- .../spark/storage/BlockObjectWriter.scala | 9 +-- 2 files changed, 50 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 33eb37fbeeba2..b1c9e2101c164 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -17,7 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; -import java.io.*; +import java.io.File; +import java.io.IOException; import scala.Tuple2; @@ -31,15 +32,18 @@ final class UnsafeSorterSpillWriter { - private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; static final int EOF_MARKER = -1; - private byte[] arr = new byte[SER_BUFFER_SIZE]; + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; private final File file; private final BlockId blockId; private BlockObjectWriter writer; - private DataOutputStream dos; public UnsafeSorterSpillWriter( BlockManager blockManager, @@ -55,7 +59,26 @@ public UnsafeSorterSpillWriter( // around this, we pass a dummy no-op serializer. writer = blockManager.getDiskWriter( blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); - dos = new DataOutputStream(writer); + } + + // Based on DataOutputStream.writeLong. + private void writeLongToBuffer(long v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 56); + writeBuffer[offset + 1] = (byte)(v >>> 48); + writeBuffer[offset + 2] = (byte)(v >>> 40); + writeBuffer[offset + 3] = (byte)(v >>> 32); + writeBuffer[offset + 4] = (byte)(v >>> 24); + writeBuffer[offset + 5] = (byte)(v >>> 16); + writeBuffer[offset + 6] = (byte)(v >>> 8); + writeBuffer[offset + 7] = (byte)(v >>> 0); + } + + // Based on DataOutputStream.writeInt. + private void writeIntToBuffer(int v, int offset) throws IOException { + writeBuffer[offset + 0] = (byte)(v >>> 24); + writeBuffer[offset + 1] = (byte)(v >>> 16); + writeBuffer[offset + 2] = (byte)(v >>> 8); + writeBuffer[offset + 3] = (byte)(v >>> 0); } public void write( @@ -63,24 +86,33 @@ public void write( long baseOffset, int recordLength, long keyPrefix) throws IOException { - dos.writeInt(recordLength); - dos.writeLong(keyPrefix); - PlatformDependent.copyMemory( - baseObject, - baseOffset + 4, - arr, - PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); - writer.write(arr, 0, recordLength); + writeIntToBuffer(recordLength, 0); + writeLongToBuffer(keyPrefix, 4); + int dataRemaining = recordLength; + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; + long recordReadPosition = baseOffset + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer), + toTransfer); + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; + } writer.recordWritten(); } public void close() throws IOException { - dos.writeInt(EOF_MARKER); + writeIntToBuffer(EOF_MARKER, 0); + writer.write(writeBuffer, 0, 4); writer.commitAndClose(); writer = null; - dos = null; - arr = null; + writeBuffer = null; } public long numberOfSpilledBytes() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index eef3c25a97b4e..7eeabd1e0489c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -211,14 +211,7 @@ private[spark] class DiskBlockObjectWriter( recordWritten() } - override def write(b: Int): Unit = { - // TOOD: re-enable the `throw new UnsupportedOperationException()` here - if (!initialized) { - open() - } - - bs.write(b) - } + override def write(b: Int): Unit = throw new UnsupportedOperationException() override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { if (!initialized) { From 62f0bb80fa1c9edd55f4bc4d9b970936f1508fb9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 15:18:35 -0700 Subject: [PATCH 15/62] Update to reflect SparkPlanTest changes --- .../execution/UnsafeExternalSortSuite.scala | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 98145d25e32db..7cdf313383d7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{SQLConf, Row} -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.TestSQLContext import scala.util.Random @@ -37,28 +35,23 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - private def createRow(values: Any*): Row = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) - } - test("basic sorting") { - - val inputData = Seq( - ("Hello", 9), - ("World", 4), - ("Hello", 7), - ("Skinny", 0), - ("Constantinople", 9) + val input = Seq( + ("Hello", 9, 1.0), + ("World", 4, 2.0), + ("Hello", 7, 8.1), + ("Skinny", 0, 2.2), + ("Constantinople", 9, 1.1) ) - val sortOrder: Seq[SortOrder] = Seq( - SortOrder(BoundReference(0, StringType, nullable = false), Ascending), - SortOrder(BoundReference(1, IntegerType, nullable = false), Descending)) + checkAnswer( + Random.shuffle(input).toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) checkAnswer( - Random.shuffle(inputData), - (input: SparkPlan) => new UnsafeExternalSort(sortOrder, global = false, input), - inputData - ) + Random.shuffle(input).toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) } } From 26c89312e97997670a9ad643bcb0b796d8aab3a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 15:21:10 -0700 Subject: [PATCH 16/62] Back out some Hive changes that aren't needed anymore --- .../SortMergeCompatibilitySuite.scala | 1 - .../UnsafeSortMergeCompatibiltySuite.scala | 41 ------------------- 2 files changed, 42 deletions(-) delete mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 7129007cdfbc2..f458567e5d7ea 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.test.TestSQLContext._ /** * Runs the test cases that are included in the hive distribution with sort merge join is true. diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala deleted file mode 100644 index 093ce3504e2c2..0000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibiltySuite.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Runs the test cases that are included in the hive distribution with sort merge join and - * unsafe external sort enabled. - */ -class UnsafeSortMergeCompatibiltySuite extends SortMergeCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, "true") - TestHive.setConf(SQLConf.UNSAFE_ENABLED, "true") - TestHive.setConf(SQLConf.EXTERNAL_SORT, "true") - } - - override def afterAll() { - TestHive.setConf(SQLConf.CODEGEN_ENABLED, "false") - TestHive.setConf(SQLConf.UNSAFE_ENABLED, "false") - TestHive.setConf(SQLConf.EXTERNAL_SORT, "false") - super.afterAll() - } -} \ No newline at end of file From 206bfa2002f40e0dc5fc19bf85a58799b0447829 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 15:23:06 -0700 Subject: [PATCH 17/62] Add some missing newlines at the ends of files --- .../spark/util/collection/unsafe/sort/UnsafeSorterIterator.java | 2 +- .../util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 2 +- .../util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 2 +- .../util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java index f63ded826afb4..16ac2e8d821ba 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -32,4 +32,4 @@ public abstract class UnsafeSorterIterator { public abstract int getRecordLength(); public abstract long getKeyPrefix(); -} \ No newline at end of file +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index bee555ad9a909..c42d698347e88 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -86,4 +86,4 @@ public int getRecordLength() { public long getKeyPrefix() { return keyPrefix; } -} \ No newline at end of file +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index ea746e03471be..8a200b5aa2d4c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -180,4 +180,4 @@ public void testSortingOnlyByPartitionId() throws Exception { // assert(tempDir.isEmpty) } -} \ No newline at end of file +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 095d9ae5975d4..0dfd73039137f 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -139,4 +139,4 @@ public int compare(long prefix1, long prefix2) { } assertEquals(dataToSort.length, iterLength); } -} \ No newline at end of file +} From ebf9eeafb0ae8a1cbf69acff20fbc77e4bcbc695 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 15:39:05 -0700 Subject: [PATCH 18/62] Harmonization with shuffle's unsafe sorter --- .../unsafe/sort/UnsafeExternalSorter.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index d0ede69e9e66c..a58a6bcb23bdd 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -26,6 +26,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +41,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this + private static final int PAGE_SIZE = 1 << 27; // 128 megabytes private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; @@ -107,6 +108,12 @@ private void openSorter() throws IOException { @VisibleForTesting public void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + numSpills, + numSpills > 1 ? " times" : " time"); + final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); spillWriters.add(spillWriter); @@ -129,14 +136,13 @@ public void spill() throws IOException { taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); numSpills++; - final long threadId = Thread.currentThread().getId(); - // TODO: messy; log _before_ spill - logger.info("Thread " + threadId + " spilling in-memory map of " + - org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + - (numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)"); openSorter(); } + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + public long freeMemory() { long memoryFreed = 0; final Iterator iter = allocatedPages.iterator(); From 1db845a097ae3882d9d84b3dabe9b9668212c2b3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:18:16 -0700 Subject: [PATCH 19/62] Many more changes to harmonize with shuffle sorter --- .../unsafe/sort/UnsafeExternalSorter.java | 169 +++++++++++------- .../unsafe/sort/UnsafeInMemorySorter.java | 43 ++--- 2 files changed, 125 insertions(+), 87 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index a58a6bcb23bdd..5a63e72940357 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -17,7 +17,13 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.io.IOException; +import java.util.LinkedList; + import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -27,12 +33,6 @@ import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; import org.apache.spark.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.Iterator; -import java.util.LinkedList; /** * External sorter based on {@link UnsafeInMemorySorter}. @@ -42,28 +42,36 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; - private int numSpills = 0; - private UnsafeInMemorySorter sorter; - private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; - private final LinkedList allocatedPages = new LinkedList(); - private final boolean spillingEnabled; - private final int fileBufferSize; private ShuffleWriteMetrics writeMetrics; + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + // These variables are reset after spilling: + private UnsafeInMemorySorter sorter; private MemoryBlock currentPage = null; private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; - private final LinkedList spillWriters = - new LinkedList(); + private final LinkedList spillWriters = new LinkedList<>(); public UnsafeExternalSorter( TaskMemoryManager memoryManager, @@ -81,41 +89,44 @@ public UnsafeExternalSorter( this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; this.initialSize = initialSize; - this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true); // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - openSorter(); + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + initializeForWriting(); } // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. - private void openSorter() throws IOException { + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); - // TODO: connect write metrics to task metrics? // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L * 2; - if (spillingEnabled) { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryAcquired != memoryRequested) { - shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire memory!"); - } + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } this.sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); } + /** + * Sort and spill the current records in response to memory pressure. + */ @VisibleForTesting - public void spill() throws IOException { + void spill() throws IOException { logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), - numSpills, - numSpills > 1 ? " times" : " time"); + spillWriters.size(), + spillWriters.size() > 1 ? " times" : " time"); final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics); + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics); spillWriters.add(spillWriter); final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); while (sortedRecords.hasNext()) { @@ -134,9 +145,7 @@ public void spill() throws IOException { shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); - numSpills++; - openSorter(); + initializeForWriting(); } private long getMemoryUsage() { @@ -145,72 +154,98 @@ private long getMemoryUsage() { public long freeMemory() { long memoryFreed = 0; - final Iterator iter = allocatedPages.iterator(); - while (iter.hasNext()) { - memoryManager.freePage(iter.next()); - shuffleMemoryManager.release(PAGE_SIZE); - memoryFreed += PAGE_SIZE; - iter.remove(); + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); } + allocatedPages.clear(); currentPage = null; currentPagePosition = -1; + freeSpaceInCurrentPage = 0; return memoryFreed; } - private void ensureSpaceInDataPage(int requiredSpace) throws IOException { + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page. - if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) { - final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); - final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); - if (memoryAcquired < memoryToGrowSortBuffer) { + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandSortBuffer(); - shuffleMemoryManager.release(oldSortBufferMemoryUsage); + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } - final long spaceInCurrentPage; - if (currentPage != null) { - spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); - } else { - spaceInCurrentPage = 0; - } - if (requiredSpace > PAGE_SIZE) { - // TODO: throw a more specific exception? - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); - } else if (requiredSpace > spaceInCurrentPage) { - if (spillingEnabled) { + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); if (memoryAcquired < PAGE_SIZE) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpill != PAGE_SIZE) { - shuffleMemoryManager.release(memoryAcquiredAfterSpill); - throw new IOException("Can't allocate memory!"); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); } } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); } - currentPage = memoryManager.allocatePage(PAGE_SIZE); - currentPagePosition = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); - logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE); } } + /** + * Write a record to the sorter. + */ public void insertRecord( Object recordBaseObject, long recordBaseOffset, int lengthInBytes, long prefix) throws IOException { // Need 4 bytes to store the record length. - ensureSpaceInDataPage(lengthInBytes + 4); + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } final long recordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 5ddabb63e649d..c084290ba6117 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -70,12 +70,12 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] sortBuffer; + private long[] pointerArray; /** * The position in the sort buffer where new records can be inserted. */ - private int sortBufferInsertPosition = 0; + private int pointerArrayInsertPosition = 0; public UnsafeInMemorySorter( final TaskMemoryManager memoryManager, @@ -83,39 +83,42 @@ public UnsafeInMemorySorter( final PrefixComparator prefixComparator, int initialSize) { assert (initialSize > 0); - this.sortBuffer = new long[initialSize * 2]; + this.pointerArray = new long[initialSize * 2]; this.memoryManager = memoryManager; - this.sorter = new Sorter(UnsafeSortDataFormat.INSTANCE); + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); } public long getMemoryUsage() { - return sortBuffer.length * 8L; + return pointerArray.length * 8L; } public boolean hasSpaceForAnotherRecord() { - return sortBufferInsertPosition + 2 < sortBuffer.length; + return pointerArrayInsertPosition + 2 < pointerArray.length; } - public void expandSortBuffer() { - final long[] oldBuffer = sortBuffer; - sortBuffer = new long[oldBuffer.length * 2]; - System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); } /** - * Insert a record into the sort buffer. + * Inserts a record to be sorted. * - * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix */ - public void insertRecord(long objectAddress, long keyPrefix) { + public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandSortBuffer(); + expandPointerArray(); } - sortBuffer[sortBufferInsertPosition] = objectAddress; - sortBufferInsertPosition++; - sortBuffer[sortBufferInsertPosition] = keyPrefix; - sortBufferInsertPosition++; + pointerArray[pointerArrayInsertPosition] = recordPointer; + pointerArrayInsertPosition++; + pointerArray[pointerArrayInsertPosition] = keyPrefix; + pointerArrayInsertPosition++; } private static final class SortedIterator extends UnsafeSorterIterator { @@ -171,7 +174,7 @@ public void loadNext() { * {@code next()} will return the same mutable object. */ public UnsafeSorterIterator getSortedIterator() { - sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); + return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); } } From 82bb0ece6b0a3b16feb53b37dcb722c960773d3b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:24:23 -0700 Subject: [PATCH 20/62] Fix IntelliJ complaint due to negated if condition --- .../util/collection/unsafe/sort/UnsafeExternalSorter.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 5a63e72940357..64e4e74b0cf37 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -265,7 +265,9 @@ public void insertRecord( public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); - if (!spillWriters.isEmpty()) { + if (spillWriters.isEmpty()) { + return inMemoryIterator; + } else { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { @@ -276,8 +278,6 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { spillMerger.addSpill(inMemoryIterator); } return spillMerger.getSortedIterator(); - } else { - return inMemoryIterator; } } } From 9869ec298c82728dfbea6884fa94b9f725a43089 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:35:31 -0700 Subject: [PATCH 21/62] Clean up Exchange code a bit --- .../catalyst/expressions/UnsafeRowConverter.scala | 15 +++++++++++++++ .../org/apache/spark/sql/execution/Exchange.scala | 14 ++------------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index b11fc245c4af9..d7cc672bb0f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent @@ -101,6 +103,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } +object UnsafeRowConverter { + def supportsSchema(schema: StructType): Boolean = { + schema.forall { field => + try { + UnsafeColumnWriter.forType(field.dataType) + true + } catch { + case e: UnsupportedOperationException => false + } + } + } +} + /** * Function for writing a column into an UnsafeRow. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6d0c97e5e23dc..9fbdfdf598897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.util.control.NonFatal - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer @@ -304,16 +302,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } val withSort = if (needSort) { - // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow - // supports the given schema. - val supportsUnsafeRowConversion: Boolean = try { - new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray) - true - } catch { - case NonFatal(e) => - false - } - if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) { + if (sqlContext.conf.unsafeEnabled + && UnsafeRowConverter.supportsSchema(withShuffle.schema)) { UnsafeExternalSort(rowOrdering, global = false, withShuffle) } else if (sqlContext.conf.externalSortEnabled) { ExternalSort(rowOrdering, global = false, withShuffle) From 6d6a1e65e8abf6d02ae1c86fe70bc3f92800b2dc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:55:41 -0700 Subject: [PATCH 22/62] Centralize logic for picking sort operator implementations --- .../apache/spark/sql/execution/Exchange.scala | 16 +++----------- .../spark/sql/execution/SparkStrategies.scala | 22 +++++++++++++++---- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 9fbdfdf598897..231a72e91263d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -302,14 +302,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } val withSort = if (needSort) { - if (sqlContext.conf.unsafeEnabled - && UnsafeRowConverter.supportsSchema(withShuffle.schema)) { - UnsafeExternalSort(rowOrdering, global = false, withShuffle) - } else if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, withShuffle) - } else { - Sort(rowOrdering, global = false, withShuffle) - } + sqlContext.planner.BasicOperators.getSortOperator( + rowOrdering, global = false, withShuffle) } else { withShuffle } @@ -337,11 +331,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ case (UnspecifiedDistribution, Seq(), child) => child case (UnspecifiedDistribution, rowOrdering, child) => - if (sqlContext.conf.externalSortEnabled) { - ExternalSort(rowOrdering, global = false, child) - } else { - Sort(rowOrdering, global = false, child) - } + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) case (dist, ordering, _) => sys.error(s"Don't know how to ensure $dist with ordering $ordering") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5daf86d817586..07206da424d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -291,6 +291,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions + /** + * Picks an appropriate sort operator. + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ + def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { + if (sqlContext.conf.unsafeEnabled && UnsafeRowConverter.supportsSchema(child.schema)) { + execution.UnsafeExternalSort(sortExprs, global, child) + } else if (sqlContext.conf.externalSortEnabled) { + execution.ExternalSort(sortExprs, global, child) + } else { + execution.Sort(sortExprs, global, child) + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -302,11 +318,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. - execution.Sort(sortExprs, global = false, planLater(child)) :: Nil - case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled => - execution.ExternalSort(sortExprs, global, planLater(child)):: Nil + getSortOperator(sortExprs, global = false, planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - execution.Sort(sortExprs, global, planLater(child)):: Nil + getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => From 90c2b6ab9812777fa08d5364b13246abe93d90f3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 16:03:27 -0700 Subject: [PATCH 23/62] Update test name --- .../collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 0dfd73039137f..e4c678ce2b5ef 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -56,11 +56,8 @@ public void testSortingEmptyInput() { assert(!iter.hasNext()); } - /** - * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. - */ @Test - public void testSortingOnlyByPartitionId() throws Exception { + public void testSortingOnlyByIntegerPrefix() throws Exception { final String[] dataToSort = new String[] { "Boba", "Pearls", From 41b88818ede3f9a529042b3cdd3c31fddada3825 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Jun 2015 00:07:44 -0700 Subject: [PATCH 24/62] Get UnsafeInMemorySorterSuite to pass (WIP) --- core/pom.xml | 20 +++++++++---------- .../sort/UnsafeInMemorySorterSuite.java | 18 +++++++++-------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index aee0d92620606..558cc3fb9f2f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -343,28 +343,28 @@ test - org.mockito - mockito-core + org.hamcrest + hamcrest-core test - org.scalacheck - scalacheck_${scala.binary.version} + org.hamcrest + hamcrest-library test - junit - junit + org.mockito + mockito-core test - org.hamcrest - hamcrest-core + org.scalacheck + scalacheck_${scala.binary.version} test - org.hamcrest - hamcrest-library + junit + junit test diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index e4c678ce2b5ef..67666e35aaeb9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -35,11 +35,11 @@ public class UnsafeInMemorySorterSuite { private static String getStringFromDataPage(Object baseObject, long baseOffset) { - final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final int strLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); final byte[] strBytes = new byte[strLength]; PlatformDependent.copyMemory( baseObject, - baseOffset + 8, + baseOffset + 4, strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, strLength); return new String(strBytes); @@ -77,8 +77,8 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { long position = dataPage.getBaseOffset(); for (String str : dataToSort) { final byte[] strBytes = str.getBytes("utf-8"); - PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length); - position += 8; + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; PlatformDependent.copyMemory( strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, @@ -114,12 +114,12 @@ public int compare(long prefix1, long prefix2) { position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { // position now points to the start of a record (which holds its length). - final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position); + final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); final String str = getStringFromDataPage(baseObject, position); final int partitionId = hashPartitioner.getPartition(str); sorter.insertRecord(address, partitionId); - position += 8 + recordLength; + position += 4 + recordLength; } final UnsafeSorterIterator iter = sorter.getSortedIterator(); int iterLength = 0; @@ -127,9 +127,11 @@ public int compare(long prefix1, long prefix2) { Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); - final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset()); + // TODO: the logic for how we manipulate record length offsets here is confusing; clean + // this up and clarify it in comments. + final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset() - 4); final long keyPrefix = iter.getKeyPrefix(); - assertTrue(Arrays.asList(dataToSort).contains(str)); + assertThat(str, isIn(Arrays.asList(dataToSort))); assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); prevPrefix = keyPrefix; iterLength++; From 7f875f9e6c1397883babf3934cea5d60dda88597 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 30 Jun 2015 15:19:25 -0700 Subject: [PATCH 25/62] Commit failing test demonstrating bug in handling objects in spills --- .../unsafe/sort/UnsafeExternalSorter.java | 2 +- .../sort/UnsafeExternalSorterSuite.java | 8 +- .../execution/UnsafeExternalRowSorter.java | 112 ++++++++++-------- .../spark/sql/execution/basicOperators.scala | 3 +- .../execution/UnsafeExternalSortSuite.scala | 50 +++++++- 5 files changed, 116 insertions(+), 59 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 64e4e74b0cf37..51382f52d124b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -118,7 +118,7 @@ private void initializeForWriting() throws IOException { * Sort and spill the current records in response to memory pressure. */ @VisibleForTesting - void spill() throws IOException { + public void spill() throws IOException { logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 8a200b5aa2d4c..0f8d5c38b5637 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -38,7 +38,6 @@ import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -56,8 +55,6 @@ public class UnsafeExternalSorterSuite { final TaskMemoryManager memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - // Compute key prefixes based on the records' partition ids - final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -138,11 +135,8 @@ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); } - /** - * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle. - */ @Test - public void testSortingOnlyByPartitionId() throws Exception { + public void testSortingOnlyByPrefix() throws Exception { final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 7c53ea7bdac24..064459e4d2568 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -24,12 +24,15 @@ import scala.collection.Iterator; import scala.math.Ordering; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -41,61 +44,70 @@ final class UnsafeExternalRowSorter { private final StructType schema; private final UnsafeRowConverter rowConverter; - private final RowComparator rowComparator; - private final PrefixComparator prefixComparator; private final Function1 prefixComputer; + private final ObjectPool objPool = new ObjectPool(128); + private final UnsafeExternalSorter sorter; + private byte[] rowConversionBuffer = new byte[1024 * 8]; public UnsafeExternalRowSorter( StructType schema, Ordering ordering, PrefixComparator prefixComparator, // TODO: if possible, avoid this boxing of the return value - Function1 prefixComputer) { + Function1 prefixComputer) throws IOException { this.schema = schema; this.rowConverter = new UnsafeRowConverter(schema); - this.rowComparator = new RowComparator(ordering, schema); - this.prefixComparator = prefixComparator; this.prefixComputer = prefixComputer; - } - - public Iterator sort(Iterator inputIterator) throws IOException { final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); - byte[] rowConversionBuffer = new byte[1024 * 8]; - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + sorter = new UnsafeExternalSorter( taskContext.taskMemoryManager(), sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, - rowComparator, + new RowComparator(ordering, schema.length(), objPool), prefixComparator, 4096, sparkEnv.conf() ); + } + + @VisibleForTesting + void insertRow(InternalRow row) throws IOException { + final int sizeRequirement = rowConverter.getSizeRequirement(row); + if (sizeRequirement > rowConversionBuffer.length) { + rowConversionBuffer = new byte[sizeRequirement]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero + // out the portion of the buffer that we'll actually write to. + Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0); + } + final int bytesWritten = rowConverter.writeRow( + row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool); + assert (bytesWritten == sizeRequirement); + final long prefix = prefixComputer.apply(row); + sorter.insertRecord( + rowConversionBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeRequirement, + prefix + ); + } + + @VisibleForTesting + void spill() throws IOException { + sorter.spill(); + } + + private void cleanupResources() { + sorter.freeMemory(); + } + + @VisibleForTesting + Iterator sort() throws IOException { try { - while (inputIterator.hasNext()) { - final InternalRow row = inputIterator.next(); - final int sizeRequirement = rowConverter.getSizeRequirement(row); - if (sizeRequirement > rowConversionBuffer.length) { - rowConversionBuffer = new byte[sizeRequirement]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero - // out the portion of the buffer that we'll actually write to. - Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0); - } - final int bytesWritten = - rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET); - assert (bytesWritten == sizeRequirement); - final long prefix = prefixComputer.apply(row); - sorter.insertRecord( - rowConversionBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeRequirement, - prefix - ); - } final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); return new AbstractScalaRowIterator() { @@ -113,7 +125,7 @@ public InternalRow next() { sortedIterator.loadNext(); if (hasNext()) { row.pointTo( - sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema); + sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool); return row; } else { final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()]; @@ -125,14 +137,12 @@ public InternalRow next() { sortedIterator.getRecordLength() ); row.backingArray = rowDataCopy; - row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema); + row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool); sorter.freeMemory(); return row; } } catch (IOException e) { - // TODO: we need to ensure that files are cleaned properly after an exception, - // so we need better cleanup methods than freeMemory(). - sorter.freeMemory(); + cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack // to re-throw the exception: PlatformDependent.throwException(e); @@ -141,30 +151,36 @@ public InternalRow next() { }; }; } catch (IOException e) { - // TODO: we need to ensure that files are cleaned properly after an exception, - // so we need better cleanup methods than freeMemory(). - sorter.freeMemory(); + cleanupResources(); throw e; } } + + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); + } + private static final class RowComparator extends RecordComparator { - private final StructType schema; private final Ordering ordering; private final int numFields; + private final ObjectPool objPool; private final UnsafeRow row1 = new UnsafeRow(); private final UnsafeRow row2 = new UnsafeRow(); - public RowComparator(Ordering ordering, StructType schema) { - this.schema = schema; - this.numFields = schema.length(); + public RowComparator(Ordering ordering, int numFields, ObjectPool objPool) { + this.numFields = numFields; this.ordering = ordering; + this.objPool = objPool; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1, numFields, schema); - row2.pointTo(baseObj2, baseOff2, numFields, schema); + row1.pointTo(baseObj1, baseOff1, numFields, objPool); + row2.pointTo(baseObj2, baseOff2, numFields, objPool); return ordering.compare(row1, row2); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index bdc3d0037ce14..8d26d661d2b2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.unsafe.sort.PrefixComparator -import org.apache.spark.{SparkEnv, HashPartitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -275,7 +274,7 @@ case class UnsafeExternalSort( val prefixComparator = new PrefixComparator { override def compare(prefix1: Long, prefix2: Long): Int = 0 } - // TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation. + // TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation. def prefixComputer(row: InternalRow): Long = 0 new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7cdf313383d7c..14d7ecc21d14b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, AttributeReference, SortOrder} +import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext @@ -54,4 +59,47 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), input.sortBy(t => (t._2, t._1))) } + + test("sorting with object columns") { + // TODO: larger input data + val input = Seq( + Row("Hello", Row(1)), + Row("World", Row(2)) + ) + + val schema = StructType( + StructField("a", StringType, nullable = false) :: + StructField("b", StructType(StructField("b", IntegerType, nullable = false) :: Nil)) :: + Nil + ) + + // Hack so that we don't need to pass in / mock TaskContext, SparkEnv, etc. Ultimately it would + // be better to not use this hack, but due to time constraints I have deferred this for + // followup PRs. + val sortResult = TestSQLContext.sparkContext.parallelize(input, 1).mapPartitions { iter => + val rows = iter.toSeq + val sortOrder = SortOrder(BoundReference(0, StringType, nullable = false), Ascending) + + val sorter = new UnsafeExternalRowSorter( + schema, + GenerateOrdering.generate(Seq(sortOrder), schema.toAttributes), + new PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + }, + x => 0L + ) + + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + + sorter.insertRow(toCatalyst(input.head).asInstanceOf[InternalRow]) + sorter.spill() + input.tail.foreach { row => + sorter.insertRow(toCatalyst(row).asInstanceOf[InternalRow]) + } + val sortedRowsIterator = sorter.sort() + sortedRowsIterator.map(CatalystTypeConverters.convertToScala(_, schema).asInstanceOf[Row]) + }.collect() + + assert(input.sortBy(t => t.getString(0)) === sortResult) + } } From 6b156fbeb5fd5dda89919ed88654a3ce3180923d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 30 Jun 2015 16:53:31 -0700 Subject: [PATCH 26/62] Some WIP work on prefix comparison. --- .../unsafe/sort/PrefixComparator.java | 3 ++ .../unsafe/sort/PrefixComparators.java | 47 +++++++++++++++++++ .../unsafe/sort/PrefixComparatorsSuite.java | 45 ++++++++++++++++++ .../spark/sql/execution/basicOperators.scala | 2 +- 4 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java create mode 100644 core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java index c41332ad117cb..45b78829e4cf7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparator.java @@ -17,10 +17,13 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.annotation.Private; + /** * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific * comparisons, such as lexicographic comparison for strings. */ +@Private public abstract class PrefixComparator { public abstract int compare(long prefix1, long prefix2); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java new file mode 100644 index 0000000000000..475430fffe93d --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.annotation.Private; + +@Private +public class PrefixComparators { + private PrefixComparators() {} + + public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); + + static final class IntPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + int a = (int) aPrefix; + int b = (int) bPrefix; + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(int value) { + return value & 0xffffffffL; + } + } + + static final class LongPrefixComparator extends PrefixComparator { + @Override + public int compare(long a, long b) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + } +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java new file mode 100644 index 0000000000000..4c4a7d5f5486a --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class PrefixComparatorsSuite { + + private static int genericComparison(Comparable a, Comparable b) { + return a.compareTo(b); + } + + @Test + public void intPrefixComparator() { + int[] testData = new int[] { 0, Integer.MIN_VALUE, Integer.MAX_VALUE, 0, 1, 2, -1, -2, 1024}; + for (int a : testData) { + for (int b : testData) { + long aPrefix = PrefixComparators.INTEGER.computePrefix(a); + long bPrefix = PrefixComparators.INTEGER.computePrefix(b); + assertEquals( + "Wrong prefix comparison results for a=" + a + " b=" + b, + genericComparison(a, b), + PrefixComparators.INTEGER.compare(aPrefix, bPrefix)); + + } + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 8d26d661d2b2c..4d58983d219da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.types.StructType -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager @@ -28,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} From d246e29e3455b4f43a4bbf2229fab5dbc7865ad5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 15:22:29 -0700 Subject: [PATCH 27/62] Fix consideration of column types when choosing sort implementation. --- .../sql/execution/UnsafeExternalRowSorter.java | 15 +++++++++++++++ .../catalyst/expressions/UnsafeRowConverter.scala | 13 ------------- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/execution/basicOperators.scala | 10 ++++++++++ 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 064459e4d2568..b4b248d215315 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -30,9 +30,12 @@ import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -164,6 +167,18 @@ public Iterator sort(Iterator inputIterator) throws IO return sort(); } + /** + * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. + */ + public static boolean supportsSchema(StructType schema) { + for (StructField field : schema.fields()) { + if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + return false; + } + } + return true; + } + private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index d7cc672bb0f43..245aa2829130f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -103,19 +103,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } -object UnsafeRowConverter { - def supportsSchema(schema: StructType): Boolean = { - schema.forall { field => - try { - UnsafeColumnWriter.forType(field.dataType) - true - } catch { - case e: UnsupportedOperationException => false - } - } - } -} - /** * Function for writing a column into an UnsafeRow. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 07206da424d27..826752c26bc35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -298,7 +298,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && UnsafeRowConverter.supportsSchema(child.schema)) { + if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { execution.UnsafeExternalSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4d58983d219da..80e9b1b0867de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -286,6 +286,16 @@ case class UnsafeExternalSort( override def outputOrdering: Seq[SortOrder] = sortOrder } +@DeveloperApi +object UnsafeExternalSort { + /** + * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. + */ + def supportsSchema(schema: StructType): Boolean = { + UnsafeExternalRowSorter.supportsSchema(schema) + } +} + /** * :: DeveloperApi :: From 6890863b40156d1e359ef84968951bb2d93fcca0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 15:24:34 -0700 Subject: [PATCH 28/62] Fix memory leak on empty inputs. --- .../spark/sql/execution/UnsafeExternalRowSorter.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index b4b248d215315..98d904a8bc2fc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -112,6 +112,11 @@ private void cleanupResources() { Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); + if (!sortedIterator.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } return new AbstractScalaRowIterator() { private final int numFields = schema.length(); @@ -141,7 +146,7 @@ public InternalRow next() { ); row.backingArray = rowDataCopy; row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool); - sorter.freeMemory(); + cleanupResources(); return row; } } catch (IOException e) { From 4c37ba622ca5c49515f96fb6fcb5435bf873d395 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 23:51:13 -0700 Subject: [PATCH 29/62] Add tests for sorting on all primitive types. --- .../unsafe/sort/UnsafeSorterSpillWriter.java | 7 +- .../execution/UnsafeExternalRowSorter.java | 12 +- .../spark/sql/execution/basicOperators.scala | 11 +- .../spark/sql/execution/SparkPlanTest.scala | 174 +++++++++++++----- .../execution/UnsafeExternalSortSuite.scala | 92 +++------ 5 files changed, 171 insertions(+), 125 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index b1c9e2101c164..ade96396ef650 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -37,8 +37,7 @@ final class UnsafeSorterSpillWriter { // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer - // data through a byte array. This array does not need to be large enough to hold a single - // record; + // data through a byte array. private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; private final File file; @@ -115,10 +114,6 @@ public void close() throws IOException { writeBuffer = null; } - public long numberOfSpilledBytes() { - return file.length(); - } - public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { return new UnsafeSorterSpillReader(blockManager, file, blockId); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 98d904a8bc2fc..6974db83e80bc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.Arrays; -import scala.Function1; import scala.collection.Iterator; import scala.math.Ordering; @@ -47,17 +46,20 @@ final class UnsafeExternalRowSorter { private final StructType schema; private final UnsafeRowConverter rowConverter; - private final Function1 prefixComputer; + private final PrefixComputer prefixComputer; private final ObjectPool objPool = new ObjectPool(128); private final UnsafeExternalSorter sorter; private byte[] rowConversionBuffer = new byte[1024 * 8]; + public static abstract class PrefixComputer { + abstract long computePrefix(InternalRow row); + } + public UnsafeExternalRowSorter( StructType schema, Ordering ordering, PrefixComparator prefixComparator, - // TODO: if possible, avoid this boxing of the return value - Function1 prefixComputer) throws IOException { + PrefixComputer prefixComputer) throws IOException { this.schema = schema; this.rowConverter = new UnsafeRowConverter(schema); this.prefixComputer = prefixComputer; @@ -90,7 +92,7 @@ void insertRow(InternalRow row) throws IOException { final int bytesWritten = rowConverter.writeRow( row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool); assert (bytesWritten == sizeRequirement); - final long prefix = prefixComputer.apply(row); + final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 80e9b1b0867de..05fec25876416 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -271,11 +270,13 @@ case class UnsafeExternalSort( assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) - val prefixComparator = new PrefixComparator { - override def compare(prefix1: Long, prefix2: Long): Int = 0 + val prefixComparator = SortPrefixUtils.getPrefixComparator(sortOrder.head) + val prefixComputer = { + val prefixComputer = SortPrefixUtils.getPrefixComputer(sortOrder.head) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = prefixComputer(row) + } } - // TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation. - def prefixComputer(row: InternalRow): Long = 0 new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) } child.execute().mapPartitions(doSort, preservesPartitioning = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 13f3be8ca28d6..220946fb3292a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -17,18 +17,16 @@ package org.apache.spark.sql.execution -import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite - import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util._ - import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} +import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal /** * Base class for writing tests for individual physical operators. For an example of how this @@ -77,6 +75,26 @@ class SparkPlanTest extends SparkFunSuite { case None => } } + + /** + * Runs the plan and makes sure the answer matches the result produced by a reference plan. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedPlanFunction: SparkPlan => SparkPlan): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } } /** @@ -84,6 +102,66 @@ class SparkPlanTest extends SparkFunSuite { */ object SparkPlanTest { + /** + * Runs the plan and makes sure the answer matches the result produced by a reference plan. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to + * instantiate a reference implementation of the physical operator + * that's being tested. The result of executing this plan will be + * treated as the source-of-truth for the test. + */ + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = { + + val outputPlan = planFunction(input.queryExecution.sparkPlan) + val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) + + val expectedAnswer: Seq[Row] = try { + executePlan(input, expectedOutputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan to calculate expected answer: + | $expectedOutputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + val actualAnswer: Seq[Row] = try { + executePlan(input, outputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + compareAnswers(actualAnswer, expectedAnswer).map { errorMessage => + s""" + | Results do not match. + | Actual result Spark plan: + | $outputPlan + | Expected result Spark plan: + | $expectedOutputPlan + | $errorMessage + """.stripMargin + } + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param input the input data to be used. @@ -98,22 +176,33 @@ object SparkPlanTest { val outputPlan = planFunction(input.queryExecution.sparkPlan) - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = outputPlan transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { - case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap + val sparkAnswer: Seq[Row] = try { + executePlan(input, outputPlan) + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } - plan.transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } + compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage => + s""" + | Results do not match for Spark plan: + | $outputPlan + | $errorMessage + """.stripMargin } + } + private def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row]): Option[String] = { def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -130,38 +219,39 @@ object SparkPlanTest { } converted.sortBy(_.toString()) } - - val sparkAnswer: Seq[Row] = try { - resolvedPlan.executeCollect().toSeq - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing Spark plan: - | $outputPlan - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { val errorMessage = s""" - | Results do not match for Spark plan: - | $outputPlan | == Results == | ${sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: + s"== Expected Answer - ${expectedAnswer.size} ==" +: prepareAnswer(expectedAnswer).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: + s"== Actual Answer - ${sparkAnswer.size} ==" +: prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} """.stripMargin - return Some(errorMessage) + Some(errorMessage) + } else { + None } + } - None + private def executePlan(input: DataFrame, outputPlan: SparkPlan): Seq[Row] = { + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { + case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap + + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + resolvedPlan.executeCollect().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 14d7ecc21d14b..b9245d57facbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -17,18 +17,15 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering -import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, AttributeReference, SortOrder} -import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import scala.util.Random + import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ -import scala.util.Random class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { @@ -40,66 +37,27 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - test("basic sorting") { - val input = Seq( - ("Hello", 9, 1.0), - ("World", 4, 2.0), - ("Hello", 7, 8.1), - ("Skinny", 0, 2.2), - ("Constantinople", 9, 1.1) - ) - - checkAnswer( - Random.shuffle(input).toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), - input.sorted) - - checkAnswer( - Random.shuffle(input).toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), - input.sortBy(t => (t._2, t._1))) - } - - test("sorting with object columns") { - // TODO: larger input data - val input = Seq( - Row("Hello", Row(1)), - Row("World", Row(2)) - ) - - val schema = StructType( - StructField("a", StringType, nullable = false) :: - StructField("b", StructType(StructField("b", IntegerType, nullable = false) :: Nil)) :: - Nil - ) - - // Hack so that we don't need to pass in / mock TaskContext, SparkEnv, etc. Ultimately it would - // be better to not use this hack, but due to time constraints I have deferred this for - // followup PRs. - val sortResult = TestSQLContext.sparkContext.parallelize(input, 1).mapPartitions { iter => - val rows = iter.toSeq - val sortOrder = SortOrder(BoundReference(0, StringType, nullable = false), Ascending) - - val sorter = new UnsafeExternalRowSorter( - schema, - GenerateOrdering.generate(Seq(sortOrder), schema.toAttributes), - new PrefixComparator { - override def compare(prefix1: Long, prefix2: Long): Int = 0 - }, - x => 0L - ) - - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) - - sorter.insertRow(toCatalyst(input.head).asInstanceOf[InternalRow]) - sorter.spill() - input.tail.foreach { row => - sorter.insertRow(toCatalyst(row).asInstanceOf[InternalRow]) + // Test sorting on different data types + (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach{ dataType => + for (nullable <- Seq(true, false)) { + RandomDataGenerator.forType(dataType, nullable).foreach { randomDataGenerator => + test(s"sorting on $dataType with nullable=$nullable") { + val inputData = Seq.fill(1024)(randomDataGenerator()).filter { + case d: Double => !d.isNaN + case f: Float => !java.lang.Float.isNaN(f) + case x => true + } + val inputDf = TestSQLContext.createDataFrame( + TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkAnswer( + inputDf, + UnsafeExternalSort('a.asc :: Nil, global = false, _: SparkPlan), + Sort('a.asc :: Nil, global = false, _: SparkPlan) + ) + } } - val sortedRowsIterator = sorter.sort() - sortedRowsIterator.map(CatalystTypeConverters.convertToScala(_, schema).asInstanceOf[Row]) - }.collect() - - assert(input.sortBy(t => t.getString(0)) === sortResult) + } } } From 95058d9123674cc44c391d0f754959338c66b730 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 1 Jul 2015 23:57:34 -0700 Subject: [PATCH 30/62] Add missing SortPrefixUtils file --- .../spark/sql/execution/SortPrefixUtils.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala new file mode 100644 index 0000000000000..0a029ab4ae0bf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator + + +object SortPrefixUtils { + + /** + * A dummy prefix comparator which always claims that prefixes are equal. This is used in cases + * where we don't know how to generate or compare prefixes for a SortOrder. + */ + private object NoOpPrefixComparator extends PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + } + + def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { + sortOrder.dataType match { + case _ => NoOpPrefixComparator + } + } + + def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { + sortOrder.dataType match { + case _ => (row: InternalRow) => 0L + } + } +} From b310c887c4738152aee58fe317b421c6672b840d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 00:10:46 -0700 Subject: [PATCH 31/62] Integrate prefix comparators for Int and Long (others coming soon) --- .../unsafe/sort/PrefixComparators.java | 5 +-- .../sql/catalyst/expressions/UnsafeRow.java | 34 ------------------- .../expressions/UnsafeRowConverter.scala | 2 -- .../spark/sql/execution/SortPrefixUtils.scala | 8 ++++- 4 files changed, 10 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 475430fffe93d..e71a31eb94487 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -24,8 +24,9 @@ public class PrefixComparators { private PrefixComparators() {} public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); - static final class IntPrefixComparator extends PrefixComparator { + public static final class IntPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { int a = (int) aPrefix; @@ -38,7 +39,7 @@ public long computePrefix(int value) { } } - static final class LongPrefixComparator extends PrefixComparator { + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 94c038c74eb3a..6c6d28341a2b3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,9 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import java.math.BigDecimal; -import java.sql.Date; -import java.util.*; import javax.annotation.Nullable; import org.apache.spark.sql.catalyst.InternalRow; @@ -286,37 +283,6 @@ public Object get(int i) { } } - /** - * Generic `get()`, for use in toString(). This method is for debugging only and is probably very - * slow to call due to having to reflect on the schema. - */ - private Object genericGet(int i) { - assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling genericGet()"; - final DataType dataType = schema.fields()[i].dataType(); - if (isNullAt(i) || dataType == NullType) { - return null; - } else if (dataType == StringType) { - return getUTF8String(i); - } else if (dataType == BooleanType) { - return getBoolean(i); - } else if (dataType == ByteType) { - return getByte(i); - } else if (dataType == ShortType) { - return getShort(i); - } else if (dataType == IntegerType) { - return getInt(i); - } else if (dataType == LongType) { - return getLong(i); - } else if (dataType == FloatType) { - return getFloat(i); - } else if (dataType == DoubleType) { - return getDouble(i); - } else { - throw new UnsupportedOperationException(); - } - } - @Override public boolean isNullAt(int i) { assertIndexIsValid(i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 245aa2829130f..b11fc245c4af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 0a029ab4ae0bf..3fc1d1986fb05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import org.apache.spark.sql.types.{LongType, IntegerType} +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} object SortPrefixUtils { @@ -35,12 +36,17 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { + case IntegerType => PrefixComparators.INTEGER + case LongType => PrefixComparators.LONG case _ => NoOpPrefixComparator } } def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { sortOrder.dataType match { + case IntegerType => (row: InternalRow) => + PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) + case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long] case _ => (row: InternalRow) => 0L } } From 66a813ee7c34c60aa3a6b6289a831a4f869d981e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 00:51:29 -0700 Subject: [PATCH 32/62] Prefix comparators for float and double --- .../unsafe/sort/PrefixComparators.java | 28 +++++++++ .../codegen/GenerateExpression.scala | 59 +++++++++++++++++++ .../expressions/CodeGenerationSuite.scala | 4 ++ .../spark/sql/execution/SortPrefixUtils.scala | 8 ++- 4 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index e71a31eb94487..c10ab26c1bd12 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -25,6 +25,8 @@ private PrefixComparators() {} public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class IntPrefixComparator extends PrefixComparator { @Override @@ -45,4 +47,30 @@ public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } } + + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala new file mode 100644 index 0000000000000..cb1480eb4aaf2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression} + +import scala.runtime.AbstractFunction1 + +object GenerateExpression extends CodeGenerator[Expression, InternalRow => Any] { + + override protected def canonicalize(in: Expression): Expression = { + ExpressionCanonicalizer.execute(in) + } + + override protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = { + BindReferences.bindReference(in, inputSchema) + } + + override protected def create(expr: Expression): InternalRow => Any = { + val ctx = newCodeGenContext() + val eval = expr.gen(ctx) + val code = + s""" + |class SpecificExpression extends + | ${classOf[AbstractFunction1[InternalRow, Any]].getName}<${classOf[InternalRow].getName}, Object> { + | + | @Override + | public SpecificExpression generate($exprType[] expr) { + | return new SpecificExpression(expr); + | } + | + | @Override + | public Object apply(InternalRow i) { + | ${eval.code} + | return ${eval.isNull} ? null : ${eval.primitive}; + | } + | } + """.stripMargin + logDebug(s"Generated expression '$expr':\n$code") + println(code) + compile(code).generate(ctx.references.toArray).asInstanceOf[InternalRow => Any] + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 481b335d15dfd..f8bc5c560e154 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ */ class CodeGenerationSuite extends SparkFunSuite { + test("generate expression") { + GenerateExpression.generate(Add(Literal(1), Literal(1))) + } + test("multithreaded eval") { import scala.concurrent._ import ExecutionContext.Implicits.global diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 3fc1d1986fb05..53718690f6431 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.types.{LongType, IntegerType} +import org.apache.spark.sql.types.{DoubleType, FloatType, LongType, IntegerType} import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -38,6 +38,8 @@ object SortPrefixUtils { sortOrder.dataType match { case IntegerType => PrefixComparators.INTEGER case LongType => PrefixComparators.LONG + case FloatType => PrefixComparators.FLOAT + case DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator } } @@ -47,6 +49,10 @@ object SortPrefixUtils { case IntegerType => (row: InternalRow) => PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long] + case FloatType => (row: InternalRow) => + PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) + case DoubleType => (row: InternalRow) => + PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) case _ => (row: InternalRow) => 0L } } From 0dfe919659127988563202c879494459c705dd8f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 01:06:13 -0700 Subject: [PATCH 33/62] Implement prefix sort for strings (albeit inefficiently). --- .../unsafe/sort/PrefixComparators.java | 39 +++++++++++++++++++ .../spark/sql/execution/SortPrefixUtils.scala | 6 ++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index c10ab26c1bd12..a57e43ac89f7a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -17,17 +17,56 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.base.Charsets; +import com.google.common.primitives.Longs; import org.apache.spark.annotation.Private; +import org.apache.spark.unsafe.types.UTF8String; @Private public class PrefixComparators { private PrefixComparators() {} + public static final StringPrefixComparator STRING = new StringPrefixComparator(); public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); public static final LongPrefixComparator LONG = new LongPrefixComparator(); public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final class StringPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + // TODO: this can certainly be done more efficiently + byte[] a = Longs.toByteArray(aPrefix); + byte[] b = Longs.toByteArray(bPrefix); + for (int i = 0; i < 8; i++) { + if (a[i] == b[i]) continue; + if (a[i] > b[i]) return -1; + else if (a[i] < b[i]) return 1; + } + return 0; + } + + public long computePrefix(UTF8String value) { + // TODO: this can certainly be done more efficiently + return value == null ? 0L : computePrefix(value.toString()); + } + + public long computePrefix(String value) { + // TODO: this can certainly be done more efficiently + if (value == null || value.length() == 0) { + return 0L; + } else { + String first4Chars = value.substring(0, Math.min(3, value.length() - 1)); + byte[] utf16Bytes = first4Chars.getBytes(Charsets.UTF_16); + byte[] padded = new byte[8]; + if (utf16Bytes.length < 8) { + System.arraycopy(utf16Bytes, 0, padded, 0, utf16Bytes.length); + } + return Longs.fromByteArray(padded); + } + } + } + public static final class IntPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 53718690f6431..8f996a2c7c182 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.types.{DoubleType, FloatType, LongType, IntegerType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -36,6 +37,7 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { + case StringType => PrefixComparators.STRING case IntegerType => PrefixComparators.INTEGER case LongType => PrefixComparators.LONG case FloatType => PrefixComparators.FLOAT @@ -46,6 +48,8 @@ object SortPrefixUtils { def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { sortOrder.dataType match { + case StringType => (row: InternalRow) => + PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) case IntegerType => (row: InternalRow) => PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long] From 939f8244e591ee491af2feb88a90d26027f75317 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 01:10:19 -0700 Subject: [PATCH 34/62] Remove code gen experiment. --- .../codegen/GenerateExpression.scala | 59 ------------------- 1 file changed, 59 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala deleted file mode 100644 index cb1480eb4aaf2..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression} - -import scala.runtime.AbstractFunction1 - -object GenerateExpression extends CodeGenerator[Expression, InternalRow => Any] { - - override protected def canonicalize(in: Expression): Expression = { - ExpressionCanonicalizer.execute(in) - } - - override protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = { - BindReferences.bindReference(in, inputSchema) - } - - override protected def create(expr: Expression): InternalRow => Any = { - val ctx = newCodeGenContext() - val eval = expr.gen(ctx) - val code = - s""" - |class SpecificExpression extends - | ${classOf[AbstractFunction1[InternalRow, Any]].getName}<${classOf[InternalRow].getName}, Object> { - | - | @Override - | public SpecificExpression generate($exprType[] expr) { - | return new SpecificExpression(expr); - | } - | - | @Override - | public Object apply(InternalRow i) { - | ${eval.code} - | return ${eval.isNull} ? null : ${eval.primitive}; - | } - | } - """.stripMargin - logDebug(s"Generated expression '$expr':\n$code") - println(code) - compile(code).generate(ctx.references.toArray).asInstanceOf[InternalRow => Any] - } -} From 5822e6fcaa8b0fb590d361b35f2ad2dda2662348 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 17:17:56 -0700 Subject: [PATCH 35/62] Fix test compilation issue --- .../spark/sql/catalyst/expressions/CodeGenerationSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index f8bc5c560e154..481b335d15dfd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -26,10 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ */ class CodeGenerationSuite extends SparkFunSuite { - test("generate expression") { - GenerateExpression.generate(Add(Literal(1), Literal(1))) - } - test("multithreaded eval") { import scala.concurrent._ import ExecutionContext.Implicits.global From 7c3c8643fb6e10c0e92ce9ef915cbd08d5f164e6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 16:42:46 -0700 Subject: [PATCH 36/62] Undo part of a SparkPlanTest change in #7162 that broke my test. --- .../scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index e72736d9e2691..03650a02d5248 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -301,7 +301,9 @@ object SparkPlanTest { val resolvedPlan = TestSQLContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap plan.transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, From 0a79d39a181c5ef44e40a7962745b74e4f05bdf3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 17:41:03 -0700 Subject: [PATCH 37/62] Revert "Undo part of a SparkPlanTest change in #7162 that broke my test." This reverts commit 7c3c8643fb6e10c0e92ce9ef915cbd08d5f164e6. --- .../scala/org/apache/spark/sql/execution/SparkPlanTest.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 03650a02d5248..e72736d9e2691 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -301,9 +301,7 @@ object SparkPlanTest { val resolvedPlan = TestSQLContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { case (a, i) => - (a.name, BoundReference(i, a.dataType, a.nullable)) - }.toMap + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan.transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, From f27be09497896d20925b79b804f4b9ea1c084118 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 18:46:59 -0700 Subject: [PATCH 38/62] Fix tests by binding attributes. --- .../org/apache/spark/sql/execution/basicOperators.scala | 5 +++-- .../apache/spark/sql/execution/UnsafeExternalSortSuite.scala | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 05fec25876416..36aba591de188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -270,9 +270,10 @@ case class UnsafeExternalSort( assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(sortOrder.head) + val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(sortOrder.head) + val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) new UnsafeExternalRowSorter.PrefixComputer { override def computePrefix(row: InternalRow): Long = prefixComputer(row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index b9245d57facbb..df3cd130d3962 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ - class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { From 88b72db57afee0fc7ffb5626fa98a7d0739ed0fe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 18:47:35 -0700 Subject: [PATCH 39/62] Test ascending and descending sort orders. --- .../execution/UnsafeExternalRowSorter.java | 1 + .../execution/UnsafeExternalSortSuite.scala | 39 ++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6974db83e80bc..55c9f4c1b9cc4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -178,6 +178,7 @@ public Iterator sort(Iterator inputIterator) throws IO * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. */ public static boolean supportsSchema(StructType schema) { + // TODO: add spilling note. for (StructField field : schema.fields()) { if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { return false; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index df3cd130d3962..38a23eaa51780 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -37,25 +37,28 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { } // Test sorting on different data types - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach{ dataType => - for (nullable <- Seq(true, false)) { - RandomDataGenerator.forType(dataType, nullable).foreach { randomDataGenerator => - test(s"sorting on $dataType with nullable=$nullable") { - val inputData = Seq.fill(1024)(randomDataGenerator()).filter { - case d: Double => !d.isNaN - case f: Float => !java.lang.Float.isNaN(f) - case x => true - } - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - checkAnswer( - inputDf, - UnsafeExternalSort('a.asc :: Nil, global = false, _: SparkPlan), - Sort('a.asc :: Nil, global = false, _: SparkPlan) - ) + // TODO: randomized spilling to ensure that merging is tested at least once for every data type. + (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => + for ( + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1024)(randomDataGenerator()).filter { + case d: Double => !d.isNaN + case f: Float => !java.lang.Float.isNaN(f) + case x => true } + val inputDf = TestSQLContext.createDataFrame( + TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkAnswer( + inputDf, + UnsafeExternalSort(sortOrder, global = false, _: SparkPlan), + Sort(sortOrder, global = false, _: SparkPlan) + ) } } } From 82e21c13058ac3bed2262ff684dae839b8e2ada9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 20:22:52 -0700 Subject: [PATCH 40/62] Force spilling in UnsafeExternalSortSuite. --- .../execution/UnsafeExternalRowSorter.java | 21 +++++++++++++++++++ .../spark/sql/execution/basicOperators.scala | 11 ++++++++-- .../execution/UnsafeExternalSortSuite.scala | 3 ++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 55c9f4c1b9cc4..2b858dd6ed385 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -44,6 +44,14 @@ final class UnsafeExternalRowSorter { + /** + * If positive, forces records to be spilled to disk at the given frequency (measured in numbers + * of records). This is only intended to be used in tests. + */ + private int testSpillFrequency = 0; + + private long numRowsInserted = 0; + private final StructType schema; private final UnsafeRowConverter rowConverter; private final PrefixComputer prefixComputer; @@ -77,6 +85,15 @@ public UnsafeExternalRowSorter( ); } + /** + * Forces spills to occur every `frequency` records. Only for use in tests. + */ + @VisibleForTesting + void setTestSpillFrequency(int frequency) { + assert frequency > 0 : "Frequency must be positive"; + testSpillFrequency = frequency; + } + @VisibleForTesting void insertRow(InternalRow row) throws IOException { final int sizeRequirement = rowConverter.getSizeRequirement(row); @@ -99,6 +116,10 @@ void insertRow(InternalRow row) throws IOException { sizeRequirement, prefix ); + numRowsInserted++; + if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) { + spill(); + } } @VisibleForTesting diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 36aba591de188..ed2a7bdfe2cce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -253,12 +253,15 @@ case class ExternalSort( * * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. */ @DeveloperApi case class UnsafeExternalSort( sortOrder: Seq[SortOrder], global: Boolean, - child: SparkPlan) + child: SparkPlan, + testSpillFrequency: Int = 0) extends UnaryNode { private[this] val schema: StructType = child.schema @@ -278,7 +281,11 @@ case class UnsafeExternalSort( override def computePrefix(row: InternalRow): Long = prefixComputer(row) } } - new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator) + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter.sort(iterator) } child.execute().mapPartitions(doSort, preservesPartitioning = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 38a23eaa51780..f92aab904f754 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -54,9 +54,10 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) + assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkAnswer( inputDf, - UnsafeExternalSort(sortOrder, global = false, _: SparkPlan), + UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 100), Sort(sortOrder, global = false, _: SparkPlan) ) } From 8d7fbe7f8759236ab7d554e973797963664590bc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 22:16:37 -0700 Subject: [PATCH 41/62] Fixes to multiple spilling-related bugs. --- .../unsafe/sort/UnsafeExternalSorter.java | 8 ++--- .../unsafe/sort/UnsafeInMemorySorter.java | 10 ++++++- .../unsafe/sort/UnsafeSorterSpillReader.java | 29 +++++++++++-------- .../unsafe/sort/UnsafeSorterSpillWriter.java | 23 ++++++++++++--- .../sort/UnsafeExternalSorterSuite.java | 19 +++++------- 5 files changed, 55 insertions(+), 34 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 51382f52d124b..52205d9fc7dd3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -126,17 +126,15 @@ public void spill() throws IOException { spillWriters.size() > 1 ? " times" : " time"); final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics); + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, + sorter.numRecords()); spillWriters.add(spillWriter); final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final Object baseObject = sortedRecords.getBaseObject(); final long baseOffset = sortedRecords.getBaseOffset(); - // TODO: this assumption that the first long holds a length is not enforced via our interfaces - // We need to either always store this via the write path (e.g. not require the caller to do - // it), or provide interfaces / hooks for customizing the physical storage format etc. - final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); + final int recordLength = sortedRecords.getRecordLength(); spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c084290ba6117..833a951d7ec08 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -89,6 +89,13 @@ public UnsafeInMemorySorter( this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); } + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pointerArrayInsertPosition / 2; + } + public long getMemoryUsage() { return pointerArray.length * 8L; } @@ -106,7 +113,8 @@ public void expandPointerArray() { } /** - * Inserts a record to be sorted. + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a 4-byte integer, followed by the record's bytes. * * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. * @param keyPrefix a user-defined key prefix diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index c42d698347e88..9249d96911cf3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -25,16 +25,21 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; +/** + * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description + * of the file format). + */ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { - private final File file; private InputStream in; private DataInputStream din; - private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? - private int nextRecordLength; - + // Variables that change with every record read: + private int recordLength; private long keyPrefix; + private int numRecordsRemaining; + + private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? private final Object baseObject = arr; private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; @@ -42,25 +47,25 @@ public UnsafeSorterSpillReader( BlockManager blockManager, File file, BlockId blockId) throws IOException { - this.file = file; - assert (file.length() > 0); + assert (file.length() > 0); final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); this.in = blockManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); - nextRecordLength = din.readInt(); + numRecordsRemaining = din.readInt(); } @Override public boolean hasNext() { - return (in != null); + return (numRecordsRemaining > 0); } @Override public void loadNext() throws IOException { + recordLength = din.readInt(); keyPrefix = din.readLong(); - ByteStreams.readFully(in, arr, 0, nextRecordLength); - nextRecordLength = din.readInt(); - if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) { + ByteStreams.readFully(in, arr, 0, recordLength); + numRecordsRemaining--; + if (numRecordsRemaining == 0) { in.close(); in = null; din = null; @@ -79,7 +84,7 @@ public long getBaseOffset() { @Override public int getRecordLength() { - return 0; + return recordLength; } @Override diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index ade96396ef650..41d0b079835ce 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -30,10 +30,14 @@ import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; +/** + * Spills a list of sorted records to disk. Spill files have the following format: + * + * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] + */ final class UnsafeSorterSpillWriter { static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - static final int EOF_MARKER = -1; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer @@ -42,22 +46,29 @@ final class UnsafeSorterSpillWriter { private final File file; private final BlockId blockId; + private final int numRecordsToWrite; private BlockObjectWriter writer; + private int numRecordsSpilled = 0; public UnsafeSorterSpillWriter( BlockManager blockManager, int fileBufferSize, - ShuffleWriteMetrics writeMetrics) { + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { final Tuple2 spilledFileInfo = blockManager.diskBlockManager().createTempLocalBlock(); this.file = spilledFileInfo._2(); this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. // Our write path doesn't actually use this serializer (since we end up calling the `write()` // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. writer = blockManager.getDiskWriter( blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); } // Based on DataOutputStream.writeLong. @@ -85,6 +96,12 @@ public void write( long baseOffset, int recordLength, long keyPrefix) throws IOException { + if (numRecordsSpilled == numRecordsToWrite) { + throw new IllegalStateException( + "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite); + } else { + numRecordsSpilled++; + } writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; @@ -107,8 +124,6 @@ public void write( } public void close() throws IOException { - writeIntToBuffer(EOF_MARKER, 0); - writer.write(writeBuffer, 0, 4); writer.commitAndClose(); writer = null; writeBuffer = null; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 0f8d5c38b5637..c1a8e623c7b1b 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -153,22 +153,17 @@ public void testSortingOnlyByPrefix() throws Exception { insertNumber(sorter, 3); sorter.spill(); insertNumber(sorter, 4); + sorter.spill(); insertNumber(sorter, 2); UnsafeSorterIterator iter = sorter.getSortedIterator(); - iter.loadNext(); - assertEquals(1, iter.getKeyPrefix()); - iter.loadNext(); - assertEquals(2, iter.getKeyPrefix()); - iter.loadNext(); - assertEquals(3, iter.getKeyPrefix()); - iter.loadNext(); - assertEquals(4, iter.getKeyPrefix()); - iter.loadNext(); - assertEquals(5, iter.getKeyPrefix()); - assertFalse(iter.hasNext()); - // TODO: check that the values are also read back properly. + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(i, iter.getKeyPrefix()); + assertEquals(4, iter.getRecordLength()); + // TODO: read rest of value. + } // TODO: test for cleanup: // assert(tempDir.isEmpty) From 87b6ed9b7b6e70b018b938a4247135db17a1744d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 23:11:28 -0700 Subject: [PATCH 42/62] Fix critical issues in test which led to false negatives. --- .../spark/sql/execution/SparkPlanTest.scala | 22 +++++++--- .../execution/UnsafeExternalSortSuite.scala | 44 +++++++++---------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index e72736d9e2691..ece9dafbb1270 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} @@ -145,12 +144,15 @@ class SparkPlanTest extends SparkFunSuite { * instantiate a reference implementation of the physical operator * that's being tested. The result of executing this plan will be * treated as the source-of-truth for the test. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedPlanFunction: SparkPlan => SparkPlan): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction) match { + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -175,7 +177,8 @@ object SparkPlanTest { def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = { + expectedPlanFunction: SparkPlan => SparkPlan, + sortAnswers: Boolean): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) @@ -210,7 +213,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer).map { errorMessage => + compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match. | Actual result Spark plan: @@ -262,7 +265,8 @@ object SparkPlanTest { private def compareAnswers( sparkAnswer: Seq[Row], - expectedAnswer: Seq[Row]): Option[String] = { + expectedAnswer: Seq[Row], + sort: Boolean = true): Option[String] = { def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to @@ -277,7 +281,11 @@ object SparkPlanTest { case o => o }) } - converted.sortBy(_.toString()) + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } } if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { val errorMessage = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index f92aab904f754..f4b8782e39b03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -38,29 +38,29 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types // TODO: randomized spilling to ensure that merging is tested at least once for every data type. - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => - for ( - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1024)(randomDataGenerator()).filter { - case d: Double => !d.isNaN - case f: Float => !java.lang.Float.isNaN(f) - case x => true - } - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) - checkAnswer( - inputDf, - UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 100), - Sort(sortOrder, global = false, _: SparkPlan) - ) + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(3)(randomDataGenerator()).filter { + case d: Double => !d.isNaN + case f: Float => !java.lang.Float.isNaN(f) + case x => true } + val inputDf = TestSQLContext.createDataFrame( + TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + checkAnswer( + inputDf, + UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 2), + Sort(sortOrder, global = false, _: SparkPlan), + sortAnswers = false + ) } } } From 5d6109d00833a3534ecb175ac5508316476737cc Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 23:26:11 -0700 Subject: [PATCH 43/62] Fix inconsistent handling / encoding of record lengths. --- .../unsafe/sort/UnsafeInMemorySorter.java | 1 + .../unsafe/sort/UnsafeSorterSpillWriter.java | 12 ++++++++++-- .../unsafe/sort/UnsafeInMemorySorterSuite.java | 16 +++++++--------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 833a951d7ec08..fc34ad9cff369 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -156,6 +156,7 @@ public boolean hasNext() { @Override public void loadNext() { + // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = sortBuffer[position]; baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 41d0b079835ce..5815c2c487ca3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -91,6 +91,14 @@ private void writeIntToBuffer(int v, int offset) throws IOException { writeBuffer[offset + 3] = (byte)(v >>> 0); } + /** + * Write a record to a spill file. + * + * @param baseObject the base object / memory page containing the record + * @param baseOffset the base offset which points directly to the record data. + * @param recordLength the length of the record. + * @param keyPrefix a sort key prefix + */ public void write( Object baseObject, long baseOffset, @@ -105,8 +113,8 @@ public void write( writeIntToBuffer(recordLength, 0); writeLongToBuffer(keyPrefix, 4); int dataRemaining = recordLength; - int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; - long recordReadPosition = baseOffset + 4; // skip over record length + int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len + long recordReadPosition = baseOffset; while (dataRemaining > 0) { final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining); PlatformDependent.copyMemory( diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 67666e35aaeb9..909500930539c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -34,14 +34,13 @@ public class UnsafeInMemorySorterSuite { - private static String getStringFromDataPage(Object baseObject, long baseOffset) { - final int strLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset); - final byte[] strBytes = new byte[strLength]; + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { + final byte[] strBytes = new byte[length]; PlatformDependent.copyMemory( baseObject, - baseOffset + 4, + baseOffset, strBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + PlatformDependent.BYTE_ARRAY_OFFSET, length); return new String(strBytes); } @@ -116,7 +115,7 @@ public int compare(long prefix1, long prefix2) { // position now points to the start of a record (which holds its length). final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, position); final long address = memoryManager.encodePageNumberAndOffset(dataPage, position); - final String str = getStringFromDataPage(baseObject, position); + final String str = getStringFromDataPage(baseObject, position + 4, recordLength); final int partitionId = hashPartitioner.getPartition(str); sorter.insertRecord(address, partitionId); position += 4 + recordLength; @@ -127,9 +126,8 @@ public int compare(long prefix1, long prefix2) { Arrays.sort(dataToSort); while (iter.hasNext()) { iter.loadNext(); - // TODO: the logic for how we manipulate record length offsets here is confusing; clean - // this up and clarify it in comments. - final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset() - 4); + final String str = + getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset(), iter.getRecordLength()); final long keyPrefix = iter.getKeyPrefix(); assertThat(str, isIn(Arrays.asList(dataToSort))); assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix)); From b81a92046bd720b331155917c2106d866038a7c2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 23:29:10 -0700 Subject: [PATCH 44/62] Temporarily enable only the passing sort tests --- .../spark/sql/execution/UnsafeExternalSortSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index f4b8782e39b03..c5ab74ad72266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -40,12 +40,12 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // TODO: randomized spilling to ensure that merging is tested at least once for every data type. for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + nullable <- Seq(false); + sortOrder <- Seq('a.asc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(3)(randomDataGenerator()).filter { + val inputData = Seq.fill(10)(randomDataGenerator()).filter { case d: Double => !d.isNaN case f: Float => !java.lang.Float.isNaN(f) case x => true @@ -57,7 +57,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkAnswer( inputDf, - UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 2), + UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 3), Sort(sortOrder, global = false, _: SparkPlan), sortAnswers = false ) From 1c7bad8c538be233a6308a970861a2ff3fe002c7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 6 Jul 2015 23:44:33 -0700 Subject: [PATCH 45/62] Make sorting of answers explicit in SparkPlanTest.checkAnswer(). --- .../spark/sql/execution/SortSuite.scala | 10 ++-- .../spark/sql/execution/SparkPlanTest.scala | 60 ++++++++++++++----- .../execution/UnsafeExternalSortSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 15 +++-- 4 files changed, 61 insertions(+), 26 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index a1e3ca11b1ad9..be59c502e8c64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -33,12 +33,14 @@ class SortSuite extends SparkPlanTest { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), - input.sorted) + ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._1, t._2)), + sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), - input.sortBy(t => (t._2, t._1))) + ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + input.sortBy(t => (t._2, t._1)), + sortAnswers = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index ece9dafbb1270..831b0f9109ab1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -46,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer) + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Unit = { + checkAnswer( + input :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans.head), + expectedAnswer, + sortAnswers) } /** @@ -61,14 +68,20 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( left: DataFrame, right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - checkAnswer(left :: right :: Nil, - (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer) + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Unit = { + checkAnswer( + left :: right :: Nil, + (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), + expectedAnswer, + sortAnswers) } /** @@ -77,12 +90,15 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[Row]): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -94,13 +110,16 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer[A <: Product : TypeTag]( input: DataFrame, planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[A]): Unit = { + expectedAnswer: Seq[A], + sortAnswers: Boolean): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows) + checkAnswer(input, planFunction, expectedRows, sortAnswers) } /** @@ -110,14 +129,17 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer[A <: Product : TypeTag]( left: DataFrame, right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[A]): Unit = { + expectedAnswer: Seq[A], + sortAnswers: Boolean): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(left, right, planFunction, expectedRows) + checkAnswer(left, right, planFunction, expectedRows, sortAnswers) } /** @@ -126,13 +148,16 @@ class SparkPlanTest extends SparkFunSuite { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ protected def checkAnswer[A <: Product : TypeTag]( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[A]): Unit = { + expectedAnswer: Seq[A], + sortAnswers: Boolean): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows) + checkAnswer(input, planFunction, expectedRows, sortAnswers) } /** @@ -231,11 +256,14 @@ object SparkPlanTest { * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate * the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. */ def checkAnswer( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[Row]): Option[String] = { + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) @@ -254,7 +282,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage => + compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match for Spark plan: | $outputPlan @@ -266,7 +294,7 @@ object SparkPlanTest { private def compareAnswers( sparkAnswer: Seq[Row], expectedAnswer: Seq[Row], - sort: Boolean = true): Option[String] = { + sort: Boolean): Option[String] = { def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index c5ab74ad72266..c697c319980dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // TODO: randomized spilling to ensure that merging is tested at least once for every data type. for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(false); + nullable <- Seq(true, false); sortOrder <- Seq('a.asc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 5707d2fb300ae..f498f8c063e5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -47,7 +47,8 @@ class OuterJoinSuite extends SparkPlanTest { (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ), + sortAnswers = true) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), @@ -55,7 +56,8 @@ class OuterJoinSuite extends SparkPlanTest { (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ), + sortAnswers = true) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), @@ -65,7 +67,8 @@ class OuterJoinSuite extends SparkPlanTest { (3, 3.0, null, null), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ), + sortAnswers = true) } test("broadcast hash outer join") { @@ -75,7 +78,8 @@ class OuterJoinSuite extends SparkPlanTest { (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ), + sortAnswers = true) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), @@ -83,6 +87,7 @@ class OuterJoinSuite extends SparkPlanTest { (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ), + sortAnswers = true) } } From b86e68440cd7fd6d8a93c76ea3f421765c916639 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 00:29:59 -0700 Subject: [PATCH 46/62] Set global = true in UnsafeExternalSortSuite. --- .../apache/spark/sql/execution/UnsafeExternalSortSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index c697c319980dd..e51f1335ec1e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -37,7 +37,6 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { } // Test sorting on different data types - // TODO: randomized spilling to ensure that merging is tested at least once for every data type. for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); nullable <- Seq(true, false); @@ -57,8 +56,8 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkAnswer( inputDf, - UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 3), - Sort(sortOrder, global = false, _: SparkPlan), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 3), + Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) } From 08701e7fbd9ab9998e652e5040e60c59cdf2ba5e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 01:04:31 -0700 Subject: [PATCH 47/62] Fix prefix comparison of null primitives. --- .../unsafe/sort/PrefixComparators.java | 8 +++++ .../spark/sql/execution/SortPrefixUtils.scala | 30 ++++++++++++++----- .../execution/UnsafeExternalSortSuite.scala | 2 +- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index a57e43ac89f7a..7a6093aaff235 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -75,6 +75,8 @@ public int compare(long aPrefix, long bPrefix) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + public final long NULL_PREFIX = computePrefix(Integer.MIN_VALUE); + public long computePrefix(int value) { return value & 0xffffffffL; } @@ -85,6 +87,8 @@ public static final class LongPrefixComparator extends PrefixComparator { public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + + public final long NULL_PREFIX = Long.MIN_VALUE; } public static final class FloatPrefixComparator extends PrefixComparator { @@ -98,6 +102,8 @@ public int compare(long aPrefix, long bPrefix) { public long computePrefix(float value) { return Float.floatToIntBits(value) & 0xffffffffL; } + + public final long NULL_PREFIX = computePrefix(Float.MIN_VALUE); } public static final class DoublePrefixComparator extends PrefixComparator { @@ -111,5 +117,7 @@ public int compare(long aPrefix, long bPrefix) { public long computePrefix(double value) { return Double.doubleToLongBits(value); } + + public final long NULL_PREFIX = computePrefix(Double.MIN_VALUE); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 8f996a2c7c182..366b09ddfbfdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -48,15 +48,29 @@ object SortPrefixUtils { def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { sortOrder.dataType match { - case StringType => (row: InternalRow) => + case StringType => (row: InternalRow) => { PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - case IntegerType => (row: InternalRow) => - PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) - case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long] - case FloatType => (row: InternalRow) => - PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - case DoubleType => (row: InternalRow) => - PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) + } + case IntegerType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGER.NULL_PREFIX + else PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) + } + case LongType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.LONG.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Long] + } + case FloatType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX + else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) + } + case DoubleType => (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX + else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) + } case _ => (row: InternalRow) => 0L } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index e51f1335ec1e5..5233c73638c85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -38,7 +38,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + dataType <- DataTypeTestUtils.atomicTypes; // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); nullable <- Seq(true, false); sortOrder <- Seq('a.asc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) From 1d7ffaa8e49f72f063a3e1d10c261749a18edd04 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 01:12:43 -0700 Subject: [PATCH 48/62] Somewhat hacky fix for descending sorts --- .../spark/sql/execution/basicOperators.scala | 14 +++++++++++++- .../sql/execution/UnsafeExternalSortSuite.scala | 5 +++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ed2a7bdfe2cce..1e7f7129f4b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -274,7 +275,18 @@ case class UnsafeExternalSort( def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + // Hack until we generate separate comparator implementations for ascending vs. descending + // (or choose to codegen them): + val prefixComparator = { + val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) + if (sortOrder.head.direction == Descending) { + new PrefixComparator { + override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) + } + } else { + comp + } + } val prefixComputer = { val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) new UnsafeExternalRowSorter.PrefixComputer { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 5233c73638c85..1dd81ee4e9fb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -38,9 +38,10 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types for ( - dataType <- DataTypeTestUtils.atomicTypes; // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); + dataType <- DataTypeTestUtils.atomicTypes // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); + if !dataType.isInstanceOf[DecimalType]; // Since we don't have an unsafe representation for decimals nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { From 613e16f91ddba990e5271d5d40ead113aee6390c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 01:35:00 -0700 Subject: [PATCH 49/62] Test with larger data. --- .../apache/spark/sql/execution/UnsafeExternalSortSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 1dd81ee4e9fb6..28729a37457ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -45,7 +45,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(10)(randomDataGenerator()).filter { + val inputData = Seq.fill(1000)(randomDataGenerator()).filter { case d: Double => !d.isNaN case f: Float => !java.lang.Float.isNaN(f) case x => true @@ -57,7 +57,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkAnswer( inputDf, - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 3), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 88aff1847dc8d896d4b8dac471206ba8729d9b13 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 01:36:44 -0700 Subject: [PATCH 50/62] NULL_PREFIX has to be negative infinity for floating point types --- .../spark/util/collection/unsafe/sort/PrefixComparators.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 7a6093aaff235..5f82579c0267b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -103,7 +103,7 @@ public long computePrefix(float value) { return Float.floatToIntBits(value) & 0xffffffffL; } - public final long NULL_PREFIX = computePrefix(Float.MIN_VALUE); + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } public static final class DoublePrefixComparator extends PrefixComparator { @@ -118,6 +118,6 @@ public long computePrefix(double value) { return Double.doubleToLongBits(value); } - public final long NULL_PREFIX = computePrefix(Double.MIN_VALUE); + public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } From 9d00afc0f3bb269a2ab31d5e3b48592bf637705b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 15:49:40 -0700 Subject: [PATCH 51/62] Clean up prefix comparators for integral types --- .../unsafe/sort/PrefixComparators.java | 24 +++------- .../unsafe/sort/PrefixComparatorsSuite.java | 45 ------------------- .../spark/sql/execution/SortPrefixUtils.scala | 44 +++++++++++++----- 3 files changed, 38 insertions(+), 75 deletions(-) delete mode 100644 core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5f82579c0267b..d5deda76c84eb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -19,6 +19,7 @@ import com.google.common.base.Charsets; import com.google.common.primitives.Longs; + import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -27,8 +28,7 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); - public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); @@ -67,22 +67,10 @@ public long computePrefix(String value) { } } - public static final class IntPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - int a = (int) aPrefix; - int b = (int) bPrefix; - return (a < b) ? -1 : (a > b) ? 1 : 0; - } - - public final long NULL_PREFIX = computePrefix(Integer.MIN_VALUE); - - public long computePrefix(int value) { - return value & 0xffffffffL; - } - } - - public static final class LongPrefixComparator extends PrefixComparator { + /** + * Prefix comparator for all integral types (boolean, byte, short, int, long). + */ + public static final class IntegralPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java deleted file mode 100644 index 4c4a7d5f5486a..0000000000000 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection.unsafe.sort; - -import org.junit.Test; -import static org.junit.Assert.*; - -public class PrefixComparatorsSuite { - - private static int genericComparison(Comparable a, Comparable b) { - return a.compareTo(b); - } - - @Test - public void intPrefixComparator() { - int[] testData = new int[] { 0, Integer.MIN_VALUE, Integer.MAX_VALUE, 0, 1, 2, -1, -2, 1024}; - for (int a : testData) { - for (int b : testData) { - long aPrefix = PrefixComparators.INTEGER.computePrefix(a); - long bPrefix = PrefixComparators.INTEGER.computePrefix(b); - assertEquals( - "Wrong prefix comparison results for a=" + a + " b=" + b, - genericComparison(a, b), - PrefixComparators.INTEGER.compare(aPrefix, bPrefix)); - - } - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 366b09ddfbfdb..2dee3542d6101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -38,8 +38,7 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { case StringType => PrefixComparators.STRING - case IntegerType => PrefixComparators.INTEGER - case LongType => PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL case FloatType => PrefixComparators.FLOAT case DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator @@ -51,16 +50,37 @@ object SortPrefixUtils { case StringType => (row: InternalRow) => { PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) } - case IntegerType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGER.NULL_PREFIX - else PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int]) - } - case LongType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.LONG.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } + case BooleanType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 + else 0 + } + case ByteType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Byte] + } + case ShortType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Short] + } + case IntegerType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Int] + } + case LongType => + (row: InternalRow) => { + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Long] + } case FloatType => (row: InternalRow) => { val exprVal = sortOrder.child.eval(row) if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX From f99a612c491431faf405e3367e78a8583f611527 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 16:41:59 -0700 Subject: [PATCH 52/62] Fix bugs in string prefix comparison. --- .../unsafe/sort/PrefixComparators.java | 32 +++++++++---------- .../unsafe/sort/PrefixComparatorsSuite.scala | 31 ++++++++++++++++++ 2 files changed, 46 insertions(+), 17 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index d5deda76c84eb..438742565c51d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -19,6 +19,7 @@ import com.google.common.base.Charsets; import com.google.common.primitives.Longs; +import com.google.common.primitives.UnsignedBytes; import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -35,36 +36,33 @@ private PrefixComparators() {} public static final class StringPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - // TODO: this can certainly be done more efficiently + // TODO: can done more efficiently byte[] a = Longs.toByteArray(aPrefix); byte[] b = Longs.toByteArray(bPrefix); for (int i = 0; i < 8; i++) { - if (a[i] == b[i]) continue; - if (a[i] > b[i]) return -1; - else if (a[i] < b[i]) return 1; + int c = UnsignedBytes.compare(a[i], b[i]); + if (c != 0) return c; } return 0; } - public long computePrefix(UTF8String value) { - // TODO: this can certainly be done more efficiently - return value == null ? 0L : computePrefix(value.toString()); - } - - public long computePrefix(String value) { - // TODO: this can certainly be done more efficiently - if (value == null || value.length() == 0) { + public long computePrefix(byte[] bytes) { + if (bytes == null) { return 0L; } else { - String first4Chars = value.substring(0, Math.min(3, value.length() - 1)); - byte[] utf16Bytes = first4Chars.getBytes(Charsets.UTF_16); byte[] padded = new byte[8]; - if (utf16Bytes.length < 8) { - System.arraycopy(utf16Bytes, 0, padded, 0, utf16Bytes.length); - } + System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); return Longs.fromByteArray(padded); } } + + public long computePrefix(String value) { + return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + } + + public long computePrefix(UTF8String value) { + return value == null ? 0L : computePrefix(value.getBytes()); + } } /** diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala new file mode 100644 index 0000000000000..05b2b77142085 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -0,0 +1,31 @@ +package org.apache.spark.util.collection.unsafe.sort + +import org.scalatest.prop.PropertyChecks + +import org.apache.spark.SparkFunSuite + +class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { + + test("String prefix comparator") { + + def testPrefixComparison(s1: String, s2: String): Unit = { + val s1Prefix = PrefixComparators.STRING.computePrefix(s1) + val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + assert( + (prefixComparisonResult == 0) || + (prefixComparisonResult < 0 && s1 < s2) || + (prefixComparisonResult > 0 && s1 > s2)) + } + + val regressionTests = Table( + ("s1", "s2"), + ("abc", "世界"), + ("你好", "世界"), + ("你好123", "你好122") + ) + + forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } + } +} From 293f1092914d3b816367fbf684b74f6ad883e40b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 16:50:02 -0700 Subject: [PATCH 53/62] Add missing license header. --- .../unsafe/sort/PrefixComparatorsSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 05b2b77142085..fa5939b081694 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection.unsafe.sort import org.scalatest.prop.PropertyChecks From d31f180da9771847aa65033ae69d0a78cbfa6404 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 7 Jul 2015 23:12:45 -0700 Subject: [PATCH 54/62] Re-enable NullType sorting test now that SPARK-8868 is fixed --- .../apache/spark/sql/execution/UnsafeExternalSortSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 28729a37457ea..33307bf2ae51b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -38,7 +38,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types for ( - dataType <- DataTypeTestUtils.atomicTypes // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) if !dataType.isInstanceOf[DecimalType]; // Since we don't have an unsafe representation for decimals nullable <- Seq(true, false); sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); From c56ec18c710669adc85851168d4486bfeaea38b6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Jul 2015 00:09:13 -0700 Subject: [PATCH 55/62] Clean up final row copying code. --- .../sql/catalyst/expressions/UnsafeRow.java | 10 ++--- .../execution/UnsafeExternalRowSorter.java | 39 +++++++++++-------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 203756bde2af8..30cd8d169441b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; - import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; @@ -57,17 +55,15 @@ */ public final class UnsafeRow extends MutableRow { - /** Hack for if we want to pass around an UnsafeRow which also carries around its backing data */ - @Nullable public byte[] backingArray; private Object baseObject; private long baseOffset; /** A pool to hold non-primitive objects */ private ObjectPool pool; - Object getBaseObject() { return baseObject; } - long getBaseOffset() { return baseOffset; } - ObjectPool getPool() { return pool; } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 2b858dd6ed385..98bc02ebdd2b3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.spark.sql.Row; import scala.collection.Iterator; import scala.math.Ordering; @@ -150,28 +151,34 @@ public boolean hasNext() { return sortedIterator.hasNext(); } + /** + * Called prior to returning this iterator's last row. This copies the row's data into an + * on-heap byte array so that the pointer to the row data will not be dangling after the + * sorter's memory pages are freed. + */ + private void detachRowFromPage(UnsafeRow row, int rowLength) { + final byte[] rowDataCopy = new byte[rowLength]; + PlatformDependent.copyMemory( + row.getBaseObject(), + row.getBaseOffset(), + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + rowLength + ); + row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, row.getPool()); + } + @Override public InternalRow next() { try { sortedIterator.loadNext(); - if (hasNext()) { - row.pointTo( - sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool); - return row; - } else { - final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()]; - PlatformDependent.copyMemory( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sortedIterator.getRecordLength() - ); - row.backingArray = rowDataCopy; - row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool); + row.pointTo( + sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool); + if (!hasNext()) { + detachRowFromPage(row, sortedIterator.getRecordLength()); cleanupResources(); - return row; } + return row; } catch (IOException e) { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack From 845bea369a69f4d7be4a82af3a7e33beb605c70b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Jul 2015 00:12:12 -0700 Subject: [PATCH 56/62] Remove unnecessary zeroing of row conversion buffer --- .../spark/sql/execution/UnsafeExternalRowSorter.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 98bc02ebdd2b3..6f771a0b03a67 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,9 +18,7 @@ package org.apache.spark.sql.execution; import java.io.IOException; -import java.util.Arrays; -import org.apache.spark.sql.Row; import scala.collection.Iterator; import scala.math.Ordering; @@ -100,12 +98,6 @@ void insertRow(InternalRow row) throws IOException { final int sizeRequirement = rowConverter.getSizeRequirement(row); if (sizeRequirement > rowConversionBuffer.length) { rowConversionBuffer = new byte[sizeRequirement]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero - // out the portion of the buffer that we'll actually write to. - Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0); } final int bytesWritten = rowConverter.writeRow( row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool); From d13ac55eaa022ad6b036320ca0ca6c9076fe19a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 8 Jul 2015 00:56:36 -0700 Subject: [PATCH 57/62] Hacky approach to copying of UnsafeRows for sort followed by limit. --- .../UnsafeFixedWidthAggregationMap.java | 12 +++++-- .../sql/catalyst/expressions/UnsafeRow.java | 27 ++++++++++++-- .../execution/UnsafeExternalRowSorter.java | 36 ++++++------------- .../expressions/UnsafeRowConverter.scala | 10 ++++-- .../expressions/UnsafeRowConverterSuite.scala | 36 ++++++++++++------- .../execution/UnsafeExternalSortSuite.scala | 27 ++++++++++++++ 6 files changed, 104 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 1e79f4b2e88e5..79d55b36dab01 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -120,9 +120,11 @@ public UnsafeFixedWidthAggregationMap( this.bufferPool = new ObjectPool(initialCapacity); InternalRow initRow = initProjection.apply(emptyRow); - this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int emptyBufferSize = bufferConverter.getSizeRequirement(initRow); + this.emptyBuffer = new byte[emptyBufferSize]; int writtenLength = bufferConverter.writeRow( - initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize, + bufferPool); assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; // re-use the empty buffer only when there is no object saved in pool. reuseEmptyBuffer = bufferPool.size() == 0; @@ -142,6 +144,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { groupingKey, groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, + groupingKeySize, keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; @@ -157,7 +160,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { // There is some objects referenced by emptyBuffer, so generate a new one InternalRow initRow = initProjection.apply(emptyRow); bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - bufferPool); + groupingKeySize, bufferPool); } loc.putNewKey( groupingKeyConversionScratchSpace, @@ -175,6 +178,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { address.getBaseObject(), address.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return currentBuffer; @@ -214,12 +218,14 @@ public MapEntry next() { keyAddress.getBaseObject(), keyAddress.getBaseOffset(), keyConverter.numFields(), + loc.getKeyLength(), keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), bufferConverter.numFields(), + loc.getValueLength(), bufferPool ); return entry; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 30cd8d169441b..b31fbed038210 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -63,11 +63,15 @@ public final class UnsafeRow extends MutableRow { public Object getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } public ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + /** The size of this row's backing data, in bytes) */ + private int sizeInBytes; + public int length() { return numFields; } /** The width of the null tracking bit set, in bytes */ @@ -95,14 +99,17 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row + * @param sizeInBytes the size of this row's backing data, in bytes * @param pool the object pool to hold arbitrary objects */ - public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { + public void pointTo( + Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; + this.sizeInBytes = sizeInBytes; this.pool = pool; } @@ -338,7 +345,23 @@ public double getDouble(int i) { @Override public InternalRow copy() { - throw new UnsupportedOperationException(); + if (pool != null) { + throw new UnsupportedOperationException( + "Copy is not supported for UnsafeRows that use object pools"); + } else { + UnsafeRow rowCopy = new UnsafeRow(); + final byte[] rowDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + rowCopy.pointTo( + rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null); + return rowCopy; + } } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 6f771a0b03a67..b94601cf6d818 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -54,7 +54,6 @@ final class UnsafeExternalRowSorter { private final StructType schema; private final UnsafeRowConverter rowConverter; private final PrefixComputer prefixComputer; - private final ObjectPool objPool = new ObjectPool(128); private final UnsafeExternalSorter sorter; private byte[] rowConversionBuffer = new byte[1024 * 8]; @@ -77,7 +76,7 @@ public UnsafeExternalRowSorter( sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, - new RowComparator(ordering, schema.length(), objPool), + new RowComparator(ordering, schema.length(), null), prefixComparator, 4096, sparkEnv.conf() @@ -100,7 +99,7 @@ void insertRow(InternalRow row) throws IOException { rowConversionBuffer = new byte[sizeRequirement]; } final int bytesWritten = rowConverter.writeRow( - row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool); + row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null); assert (bytesWritten == sizeRequirement); final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( @@ -143,31 +142,18 @@ public boolean hasNext() { return sortedIterator.hasNext(); } - /** - * Called prior to returning this iterator's last row. This copies the row's data into an - * on-heap byte array so that the pointer to the row data will not be dangling after the - * sorter's memory pages are freed. - */ - private void detachRowFromPage(UnsafeRow row, int rowLength) { - final byte[] rowDataCopy = new byte[rowLength]; - PlatformDependent.copyMemory( - row.getBaseObject(), - row.getBaseOffset(), - rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - rowLength - ); - row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, row.getPool()); - } - @Override public InternalRow next() { try { sortedIterator.loadNext(); row.pointTo( - sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool); + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + numFields, + sortedIterator.getRecordLength(), + null); if (!hasNext()) { - detachRowFromPage(row, sortedIterator.getRecordLength()); + row.copy(); // so that we don't have dangling pointers to freed page cleanupResources(); } return row; @@ -198,7 +184,7 @@ public Iterator sort(Iterator inputIterator) throws IO * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. */ public static boolean supportsSchema(StructType schema) { - // TODO: add spilling note. + // TODO: add spilling note to explain why we do this for now: for (StructField field : schema.fields()) { if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { return false; @@ -222,8 +208,8 @@ public RowComparator(Ordering ordering, int numFields, ObjectPool o @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1, numFields, objPool); - row2.pointTo(baseObj2, baseOff2, numFields, objPool); + row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool); + row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 1f395497a9839..6af5e6200e57b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param row the row to convert * @param baseObject the base object of the destination address * @param baseOffset the base offset of the destination address + * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)` * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) + def writeRow( + row: InternalRow, + baseObject: Object, + baseOffset: Long, + rowLengthInBytes: Int, + pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) if (writers.length > 0) { // zero-out the bitset diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 96d4e64ea344a..aa019b408ac61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -44,11 +44,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) @@ -73,12 +75,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = converter.writeRow( + row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() val pool = new ObjectPool(10) - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") assert(unsafeRow.get(2) === "World".getBytes) @@ -111,12 +115,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) assert(sizeRequired === 8 + (8 * 3)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) assert(numBytesWritten === sizeRequired) assert(pool.size === 2) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.get(1) === Decimal(1)) assert(unsafeRow.get(2) === Array(2)) @@ -142,11 +148,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + val numBytesWritten = + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.pointTo( + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow @@ -190,12 +198,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, null) assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, null) for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } @@ -233,10 +243,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val pool = new ObjectPool(1) val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, + sizeRequired, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, + sizeRequired, pool) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 33307bf2ae51b..6e32ea99a7fa5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -36,6 +36,33 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } + ignore("sort followed by limit should not leak memory") { + // TODO: this test is going to fail until we implement a proper iterator interface + // with a close() method. + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + checkAnswer( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sort followed by limit") { + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") + try { + checkAnswer( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } finally { + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + + } + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) From cd05866ade3f342deb9c1ed03ad6dcf852841bad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Jul 2015 01:47:55 -0700 Subject: [PATCH 58/62] Fix scalastyle --- .../util/collection/unsafe/sort/PrefixComparatorsSuite.scala | 2 ++ .../apache/spark/sql/execution/UnsafeExternalSortSuite.scala | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index fa5939b081694..dd505dfa7d758 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -35,12 +35,14 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { (prefixComparisonResult > 0 && s1 > s2)) } + // scalastyle:off val regressionTests = Table( ("s1", "s2"), ("abc", "世界"), ("你好", "世界"), ("你好123", "你好122") ) + // scalastyle:on forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 6e32ea99a7fa5..ab7cbe491c449 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -66,7 +66,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) - if !dataType.isInstanceOf[DecimalType]; // Since we don't have an unsafe representation for decimals + if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals nullable <- Seq(true, false); sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) From 2f48777086fbf85765f72aa890fc62e2e162afd6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Jul 2015 19:13:41 -0700 Subject: [PATCH 59/62] Add test and fix bug for sorting empty arrays --- .../unsafe/sort/UnsafeExternalSorter.java | 3 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 6 ++-- .../unsafe/sort/UnsafeSorterSpillWriter.java | 3 ++ .../sort/UnsafeExternalSorterSuite.java | 30 +++++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 52205d9fc7dd3..4d6731ee60af3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -263,11 +263,12 @@ public void insertRecord( public UnsafeSorterIterator getSortedIterator() throws IOException { final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); if (spillWriters.isEmpty()) { return inMemoryIterator; } else { final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator); + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpill(spillWriter.getReader(blockManager)); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 2fb41fb2d402f..8272c2a5be0d1 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -27,7 +27,8 @@ final class UnsafeSorterSpillMerger { public UnsafeSorterSpillMerger( final RecordComparator recordComparator, - final PrefixComparator prefixComparator) { + final PrefixComparator prefixComparator, + final int numSpills) { final Comparator comparator = new Comparator() { @Override @@ -43,8 +44,7 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { } } }; - // TODO: the size is often known; incorporate size hints here. - priorityQueue = new PriorityQueue(10, comparator); + priorityQueue = new PriorityQueue(numSpills, comparator); } public void addSpill(UnsafeSorterIterator spillReader) throws IOException { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 5815c2c487ca3..b8d66659804ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -128,6 +128,9 @@ public void write( dataRemaining -= toTransfer; freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE; } + if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) { + writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer)); + } writer.recordWritten(); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index c1a8e623c7b1b..ea8755e21eb68 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -169,4 +169,34 @@ public void testSortingOnlyByPrefix() throws Exception { // assert(tempDir.isEmpty) } + @Test + public void testSortingEmptyArrays() throws Exception { + + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + sorter.insertRecord(null, 0, 0, 0); + sorter.insertRecord(null, 0, 0, 0); + + UnsafeSorterIterator iter = sorter.getSortedIterator(); + + for (int i = 1; i <= 5; i++) { + iter.loadNext(); + assertEquals(0, iter.getKeyPrefix()); + assertEquals(0, iter.getRecordLength()); + } + } + } From 513520001b25aac5fc2ba3a160fb6964dc0a52a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Jul 2015 19:50:12 -0700 Subject: [PATCH 60/62] Fix spill reading for large rows; add test --- .../unsafe/sort/UnsafeSorterSpillReader.java | 8 ++++++-- .../spark/sql/execution/UnsafeExternalSortSuite.scala | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 9249d96911cf3..29e9e0f30f934 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -39,8 +39,8 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private long keyPrefix; private int numRecordsRemaining; - private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)? - private final Object baseObject = arr; + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET; public UnsafeSorterSpillReader( @@ -63,6 +63,10 @@ public boolean hasNext() { public void loadNext() throws IOException { recordLength = din.readInt(); keyPrefix = din.readLong(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } ByteStreams.readFully(in, arr, 0, recordLength); numRecordsRemaining--; if (numRecordsRemaining == 0) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index ab7cbe491c449..d5224bb03fcc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -63,6 +63,17 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { } } + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkAnswer( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + // Test sorting on different data types for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) From 35dad9f1c3a50e3d3b428f4fe1cabf8cfafb88a5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 9 Jul 2015 23:18:06 -0700 Subject: [PATCH 61/62] Make sortAnswers = false the default in SparkPlanTest --- .../spark/sql/execution/SparkPlanTest.scala | 14 +++++++------- .../sql/execution/joins/OuterJoinSuite.scala | 15 +++++---------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 831b0f9109ab1..6c7d39fe84c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -53,7 +53,7 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { checkAnswer( input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), @@ -76,7 +76,7 @@ class SparkPlanTest extends SparkFunSuite { right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { checkAnswer( left :: right :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), @@ -97,7 +97,7 @@ class SparkPlanTest extends SparkFunSuite { input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -117,7 +117,7 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[A], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) checkAnswer(input, planFunction, expectedRows, sortAnswers) } @@ -137,7 +137,7 @@ class SparkPlanTest extends SparkFunSuite { right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, expectedAnswer: Seq[A], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) checkAnswer(left, right, planFunction, expectedRows, sortAnswers) } @@ -155,7 +155,7 @@ class SparkPlanTest extends SparkFunSuite { input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[A], - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { val expectedRows = expectedAnswer.map(Row.fromTuple) checkAnswer(input, planFunction, expectedRows, sortAnswers) } @@ -176,7 +176,7 @@ class SparkPlanTest extends SparkFunSuite { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean): Unit = { + sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { case Some(errorMessage) => fail(errorMessage) case None => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index f498f8c063e5b..5707d2fb300ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -47,8 +47,7 @@ class OuterJoinSuite extends SparkPlanTest { (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - ), - sortAnswers = true) + )) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), @@ -56,8 +55,7 @@ class OuterJoinSuite extends SparkPlanTest { (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - ), - sortAnswers = true) + )) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), @@ -67,8 +65,7 @@ class OuterJoinSuite extends SparkPlanTest { (3, 3.0, null, null), (null, null, 3, 2.0), (null, null, 4, 1.0) - ), - sortAnswers = true) + )) } test("broadcast hash outer join") { @@ -78,8 +75,7 @@ class OuterJoinSuite extends SparkPlanTest { (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - ), - sortAnswers = true) + )) checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), @@ -87,7 +83,6 @@ class OuterJoinSuite extends SparkPlanTest { (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - ), - sortAnswers = true) + )) } } From 6beb4674999820126861fdac99fb84f8ea5d57ff Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Jul 2015 12:17:18 -0700 Subject: [PATCH 62/62] Remove a bunch of overloaded methods to avoid default args. issue --- .../spark/sql/execution/SortSuite.scala | 5 +- .../spark/sql/execution/SparkPlanTest.scala | 70 ++----------------- .../execution/UnsafeExternalSortSuite.scala | 8 +-- .../sql/execution/joins/OuterJoinSuite.scala | 21 +++--- 4 files changed, 25 insertions(+), 79 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index be59c502e8c64..a2c10fdaf6cdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ class SortSuite extends SparkPlanTest { @@ -34,13 +35,13 @@ class SortSuite extends SparkPlanTest { checkAnswer( input.toDF("a", "b", "c"), ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), - input.sortBy(t => (t._1, t._2)), + input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), - input.sortBy(t => (t._2, t._1)), + input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 6c7d39fe84c31..6a8f394545816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -54,7 +54,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - checkAnswer( + doCheckAnswer( input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer, @@ -71,13 +71,13 @@ class SparkPlanTest extends SparkFunSuite { * @param sortAnswers if true, the answers will be sorted by their toString representations prior * to being compared. */ - protected def checkAnswer( + protected def checkAnswer2( left: DataFrame, right: DataFrame, planFunction: (SparkPlan, SparkPlan) => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - checkAnswer( + doCheckAnswer( left :: right :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer, @@ -87,13 +87,13 @@ class SparkPlanTest extends SparkFunSuite { /** * Runs the plan and makes sure the answer matches the expected result. * @param input the input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. + * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to + * instantiate the physical operator that's being tested. * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param sortAnswers if true, the answers will be sorted by their toString representations prior * to being compared. */ - protected def checkAnswer( + protected def doCheckAnswer( input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], @@ -104,62 +104,6 @@ class SparkPlanTest extends SparkFunSuite { } } - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer[A <: Product : TypeTag]( - input: DataFrame, - planFunction: SparkPlan => SparkPlan, - expectedAnswer: Seq[A], - sortAnswers: Boolean = true): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows, sortAnswers) - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer[A <: Product : TypeTag]( - left: DataFrame, - right: DataFrame, - planFunction: (SparkPlan, SparkPlan) => SparkPlan, - expectedAnswer: Seq[A], - sortAnswers: Boolean = true): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(left, right, planFunction, expectedRows, sortAnswers) - } - - /** - * Runs the plan and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate - * the physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer[A <: Product : TypeTag]( - input: Seq[DataFrame], - planFunction: Seq[SparkPlan] => SparkPlan, - expectedAnswer: Seq[A], - sortAnswers: Boolean = true): Unit = { - val expectedRows = expectedAnswer.map(Row.fromTuple) - checkAnswer(input, planFunction, expectedRows, sortAnswers) - } - /** * Runs the plan and makes sure the answer matches the result produced by a reference plan. * @param input the input data to be used. @@ -172,7 +116,7 @@ class SparkPlanTest extends SparkFunSuite { * @param sortAnswers if true, the answers will be sorted by their toString representations prior * to being compared. */ - protected def checkAnswer( + protected def checkThatPlansAgree( input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index d5224bb03fcc8..4f4c1f28564cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // TODO: this test is going to fail until we implement a proper iterator interface // with a close() method. TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") - checkAnswer( + checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), @@ -51,7 +51,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { test("sort followed by limit") { TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") try { - checkAnswer( + checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), @@ -66,7 +66,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { test("sorting does not crash for large inputs") { val sortOrder = 'a.asc :: Nil val stringLength = 1024 * 1024 * 2 - checkAnswer( + checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), @@ -93,7 +93,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) - checkAnswer( + checkThatPlansAgree( inputDf, UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), Sort(sortOrder, global = true, _: SparkPlan), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 5707d2fb300ae..2c27da596bc4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.joins +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan} import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} @@ -41,23 +42,23 @@ class OuterJoinSuite extends SparkPlanTest { val condition = Some(LessThan('b, 'd)) test("shuffled hash outer join") { - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), Seq( (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), Seq( (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right), Seq( (1, 2.0, null, null), @@ -65,24 +66,24 @@ class OuterJoinSuite extends SparkPlanTest { (3, 3.0, null, null), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) } test("broadcast hash outer join") { - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right), Seq( (1, 2.0, null, null), (2, 1.0, 2, 3.0), (3, 3.0, null, null) - )) + ).map(Row.fromTuple)) - checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) => + checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) => BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right), Seq( (2, 1.0, 2, 3.0), (null, null, 3, 2.0), (null, null, 4, 1.0) - )) + ).map(Row.fromTuple)) } }