Skip to content

Commit 5eb89f6

Browse files
committed
[SPARK-9577][SQL] Surface concrete iterator types in various sort classes.
We often return abstract iterator types in various sort-related classes (e.g. UnsafeKVExternalSorter). It is better to return a more concrete type, so that the call site is typed against the concrete class and the JIT compiler can inline the iterator calls.

Author: Reynold Xin <[email protected]>

Closes #7911 from rxin/surface-concrete-type and squashes the following commits:

0422add [Reynold Xin] [SPARK-9577][SQL] Surface concrete iterator types in various sort classes.
1 parent 3b0e444 commit 5eb89f6

File tree

4 files changed: 65 additions and 85 deletions.

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -428,7 +428,7 @@ public void insertKVRecord(
428428

429429
public UnsafeSorterIterator getSortedIterator() throws IOException {
430430
assert(inMemSorter != null);
431-
final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
431+
final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
432432
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
433433
if (spillWriters.isEmpty()) {
434434
return inMemoryIterator;

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ public void insertRecord(long recordPointer, long keyPrefix) {
133133
pointerArrayInsertPosition++;
134134
}
135135

136-
private static final class SortedIterator extends UnsafeSorterIterator {
136+
public static final class SortedIterator extends UnsafeSorterIterator {
137137

138138
private final TaskMemoryManager memoryManager;
139139
private final int sortBufferInsertPosition;
@@ -144,7 +144,7 @@ private static final class SortedIterator extends UnsafeSorterIterator {
144144
private long keyPrefix;
145145
private int recordLength;
146146

147-
SortedIterator(
147+
private SortedIterator(
148148
TaskMemoryManager memoryManager,
149149
int sortBufferInsertPosition,
150150
long[] sortBuffer) {
@@ -186,7 +186,7 @@ public void loadNext() {
186186
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
187187
* {@code next()} will return the same mutable object.
188188
*/
189-
public UnsafeSorterIterator getSortedIterator() {
189+
public SortedIterator getSortedIterator() {
190190
sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
191191
return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
192192
}

sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -134,66 +134,15 @@ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
134134
value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
135135
}
136136

137-
public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
137+
public KVSorterIterator sortedIterator() throws IOException {
138138
try {
139139
final UnsafeSorterIterator underlying = sorter.getSortedIterator();
140140
if (!underlying.hasNext()) {
141141
// Since we won't ever call next() on an empty iterator, we need to clean up resources
142142
// here in order to prevent memory leaks.
143143
cleanupResources();
144144
}
145-
146-
return new KVIterator<UnsafeRow, UnsafeRow>() {
147-
private UnsafeRow key = new UnsafeRow();
148-
private UnsafeRow value = new UnsafeRow();
149-
private int numKeyFields = keySchema.size();
150-
private int numValueFields = valueSchema.size();
151-
152-
@Override
153-
public boolean next() throws IOException {
154-
try {
155-
if (underlying.hasNext()) {
156-
underlying.loadNext();
157-
158-
Object baseObj = underlying.getBaseObject();
159-
long recordOffset = underlying.getBaseOffset();
160-
int recordLen = underlying.getRecordLength();
161-
162-
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
163-
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
164-
int valueLen = recordLen - keyLen - 4;
165-
166-
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
167-
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
168-
169-
return true;
170-
} else {
171-
key = null;
172-
value = null;
173-
cleanupResources();
174-
return false;
175-
}
176-
} catch (IOException e) {
177-
cleanupResources();
178-
throw e;
179-
}
180-
}
181-
182-
@Override
183-
public UnsafeRow getKey() {
184-
return key;
185-
}
186-
187-
@Override
188-
public UnsafeRow getValue() {
189-
return value;
190-
}
191-
192-
@Override
193-
public void close() {
194-
cleanupResources();
195-
}
196-
};
145+
return new KVSorterIterator(underlying);
197146
} catch (IOException e) {
198147
cleanupResources();
199148
throw e;
@@ -233,4 +182,61 @@ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff
233182
return ordering.compare(row1, row2);
234183
}
235184
}
185+
186+
public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
187+
private UnsafeRow key = new UnsafeRow();
188+
private UnsafeRow value = new UnsafeRow();
189+
private final int numKeyFields = keySchema.size();
190+
private final int numValueFields = valueSchema.size();
191+
private final UnsafeSorterIterator underlying;
192+
193+
private KVSorterIterator(UnsafeSorterIterator underlying) {
194+
this.underlying = underlying;
195+
}
196+
197+
@Override
198+
public boolean next() throws IOException {
199+
try {
200+
if (underlying.hasNext()) {
201+
underlying.loadNext();
202+
203+
Object baseObj = underlying.getBaseObject();
204+
long recordOffset = underlying.getBaseOffset();
205+
int recordLen = underlying.getRecordLength();
206+
207+
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
208+
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
209+
int valueLen = recordLen - keyLen - 4;
210+
211+
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
212+
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
213+
214+
return true;
215+
} else {
216+
key = null;
217+
value = null;
218+
cleanupResources();
219+
return false;
220+
}
221+
} catch (IOException e) {
222+
cleanupResources();
223+
throw e;
224+
}
225+
}
226+
227+
@Override
228+
public UnsafeRow getKey() {
229+
return key;
230+
}
231+
232+
@Override
233+
public UnsafeRow getValue() {
234+
return value;
235+
}
236+
237+
@Override
238+
public void close() {
239+
cleanupResources();
240+
}
241+
};
236242
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
package org.apache.spark.sql.execution.aggregate
1919

20-
import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
2120
import org.apache.spark.unsafe.KVIterator
2221
import org.apache.spark.{SparkEnv, TaskContext}
2322
import org.apache.spark.sql.catalyst.InternalRow
2423
import org.apache.spark.sql.catalyst.expressions._
2524
import org.apache.spark.sql.catalyst.expressions.aggregate._
25+
import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
2626
import org.apache.spark.sql.types.StructType
2727

2828
/**
@@ -230,7 +230,7 @@ class UnsafeHybridAggregationIterator(
230230
}
231231

232232
// Step 5: Get the sorted iterator from the externalSorter.
233-
val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
233+
val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
234234

235235
// Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
236236
// For an aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
@@ -368,31 +368,5 @@ object UnsafeHybridAggregationIterator {
368368
newMutableProjection,
369369
outputsUnsafeRows)
370370
}
371-
372-
def createFromKVIterator(
373-
groupingKeyAttributes: Seq[Attribute],
374-
valueAttributes: Seq[Attribute],
375-
inputKVIterator: KVIterator[UnsafeRow, InternalRow],
376-
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
377-
nonCompleteAggregateAttributes: Seq[Attribute],
378-
completeAggregateExpressions: Seq[AggregateExpression2],
379-
completeAggregateAttributes: Seq[Attribute],
380-
initialInputBufferOffset: Int,
381-
resultExpressions: Seq[NamedExpression],
382-
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
383-
outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
384-
new UnsafeHybridAggregationIterator(
385-
groupingKeyAttributes,
386-
valueAttributes,
387-
inputKVIterator,
388-
nonCompleteAggregateExpressions,
389-
nonCompleteAggregateAttributes,
390-
completeAggregateExpressions,
391-
completeAggregateAttributes,
392-
initialInputBufferOffset,
393-
resultExpressions,
394-
newMutableProjection,
395-
outputsUnsafeRows)
396-
}
397371
// scalastyle:on
398372
}

0 commit comments

Comments
 (0)