Commit a244e20

Author: Davies Liu (committed)
Merge branch 'master' of github.com:apache/spark into gen_bench
2 parents 62eb43d + 99a6e3c, commit a244e20

File tree: 11 files changed, +604 −62 lines

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 9 additions & 5 deletions
@@ -255,8 +255,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
   var memoryThreshold = initialMemoryThreshold
   // Memory to request as a multiple of current vector size
   val memoryGrowthFactor = 1.5
-  // Previous unroll memory held by this task, for releasing later (only at the very end)
-  val previousMemoryReserved = currentUnrollMemoryForThisTask
+  // Keep track of pending unroll memory reserved by this method.
+  var pendingMemoryReserved = 0L
   // Underlying vector for unrolling the block
   var vector = new SizeTrackingVector[Any]

@@ -266,6 +266,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
   if (!keepUnrolling) {
     logWarning(s"Failed to reserve initial memory threshold of " +
       s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+  } else {
+    pendingMemoryReserved += initialMemoryThreshold
   }

   // Unroll this block safely, checking whether we have exceeded our threshold periodically
@@ -278,6 +280,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
         if (currentSize >= memoryThreshold) {
           val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
           keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+          if (keepUnrolling) {
+            pendingMemoryReserved += amountToRequest
+          }
           // New threshold is currentSize * memoryGrowthFactor
           memoryThreshold += amountToRequest
         }
@@ -304,10 +309,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
           // release the unroll memory yet. Instead, we transfer it to pending unroll memory
           // so `tryToPut` can further transfer it to normal storage memory later.
           // TODO: we can probably express this without pending unroll memory (SPARK-10907)
-          val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved
-          unrollMemoryMap(taskAttemptId) -= amountToTransferToPending
+          unrollMemoryMap(taskAttemptId) -= pendingMemoryReserved
           pendingUnrollMemoryMap(taskAttemptId) =
-            pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending
+            pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + pendingMemoryReserved
         }
       } else {
         // Otherwise, if we return an iterator, we can only release the unroll memory when
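To make the accounting change above concrete, here is a minimal, self-contained Scala sketch (not the Spark MemoryStore API; the reserve helper and the byte amounts are made up for illustration) of the pattern the patch moves to: record only the unroll memory that this call itself reserved, and transfer exactly that amount to the pending pool at the end, leaving memory the task already held untouched.

object UnrollAccountingSketch {
  def main(args: Array[String]): Unit = {
    val alreadyHeldByTask = 512L      // unroll memory the task held before this call
    var pendingMemoryReserved = 0L    // memory reserved by this call only

    // Stand-in for reserveUnrollMemoryForThisTask; here every request succeeds.
    def reserve(bytes: Long): Boolean = bytes >= 0

    val initialThreshold = 1024L
    if (reserve(initialThreshold)) pendingMemoryReserved += initialThreshold

    val amountToRequest = 2048L
    if (reserve(amountToRequest)) pendingMemoryReserved += amountToRequest

    // At the end, only what this call reserved moves to the pending pool.
    val transferredToPending = pendingMemoryReserved
    println(s"transferred $transferredToPending bytes; pre-existing $alreadyHeldByTask bytes untouched")
  }
}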

sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala

Lines changed: 14 additions & 4 deletions
@@ -47,9 +47,9 @@ object RandomDataGenerator {
    */
   private val PROBABILITY_OF_NULL: Float = 0.1f

-  private val MAX_STR_LEN: Int = 1024
-  private val MAX_ARR_SIZE: Int = 128
-  private val MAX_MAP_SIZE: Int = 128
+  final val MAX_STR_LEN: Int = 1024
+  final val MAX_ARR_SIZE: Int = 128
+  final val MAX_MAP_SIZE: Int = 128

   /**
    * Helper function for constructing a biased random number generator which returns "interesting"
@@ -208,7 +208,17 @@ object RandomDataGenerator {
           forType(valueType, nullable = valueContainsNull, rand)
         ) yield {
           () => {
-            Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
+            val length = rand.nextInt(MAX_MAP_SIZE)
+            val keys = scala.collection.mutable.HashSet(Seq.fill(length)(keyGenerator()): _*)
+            // In case the number of different keys is not enough, set a max iteration to avoid
+            // infinite loop.
+            var count = 0
+            while (keys.size < length && count < MAX_MAP_SIZE) {
+              keys += keyGenerator()
+              count += 1
+            }
+            val values = Seq.fill(keys.size)(valueGenerator())
+            keys.zip(values).toMap
           }
         }
       }
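As a standalone illustration of the key de-duplication added above, the sketch below (hypothetical helper names, not the Spark test utility) draws a target number of keys into a HashSet and tops it up for at most a bounded number of extra draws, so a small key space cannot cause an infinite loop; the map simply ends up smaller when not enough distinct keys exist.

import scala.collection.mutable
import scala.util.Random

object UniqueKeyMapSketch {
  def randomMap[K, V](length: Int, maxExtraDraws: Int,
      keyGen: () => K, valueGen: () => V): Map[K, V] = {
    val keys = mutable.HashSet(Seq.fill(length)(keyGen()): _*)
    var draws = 0
    // Top up with extra draws, but give up after maxExtraDraws to avoid looping forever.
    while (keys.size < length && draws < maxExtraDraws) {
      keys += keyGen()
      draws += 1
    }
    keys.iterator.map(k => k -> valueGen()).toMap
  }

  def main(args: Array[String]): Unit = {
    val rand = new Random(42)
    // Only 4 distinct keys exist, so the map stays at size <= 4 even though 10 were requested.
    val m = randomMap(10, 10, () => rand.nextInt(4), () => rand.nextInt())
    println(s"requested 10 entries, got ${m.size}")
  }
}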

sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -95,4 +95,15 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
     }
   }

+  test("check size of generated map") {
+    val mapType = MapType(IntegerType, IntegerType)
+    for (seed <- 1 to 1000) {
+      val generator = RandomDataGenerator.forType(
+        mapType, nullable = false, rand = new Random(seed)).get
+      val maps = Seq.fill(100)(generator().asInstanceOf[Map[Int, Int]])
+      val expectedTotalElements = 100 / 2 * RandomDataGenerator.MAX_MAP_SIZE
+      val deviation = math.abs(maps.map(_.size).sum - expectedTotalElements)
+      assert(deviation.toDouble / expectedTotalElements < 2e-1)
+    }
+  }
 }
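A quick check (not part of the patch) of the arithmetic behind the test's baseline: rand.nextInt(MAX_MAP_SIZE) is roughly uniform on 0 to MAX_MAP_SIZE - 1, so 100 maps should hold about 100 * 63.5 = 6350 elements, which sits well within the 20% tolerance around the test's 100 / 2 * MAX_MAP_SIZE = 6400 baseline.

object ExpectedMapSizeCheck {
  def main(args: Array[String]): Unit = {
    val maxMapSize = 128
    val numMaps = 100
    val baseline = numMaps / 2 * maxMapSize                 // 6400, as in the test
    val approxExpected = numMaps * (maxMapSize - 1) / 2.0   // 6350, mean of nextInt(128) per map
    val relativeGap = math.abs(approxExpected - baseline) / baseline
    println(f"baseline=$baseline approx=$approxExpected%.0f gap=$relativeGap%.4f") // gap is about 0.008, well under 0.2
  }
}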

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java

Lines changed: 130 additions & 16 deletions
@@ -21,6 +21,7 @@
 import java.nio.ByteBuffer;
 import java.util.List;

+import org.apache.commons.lang.NotImplementedException;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.TaskAttemptContext;
 import org.apache.parquet.Preconditions;
@@ -41,6 +42,7 @@
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException {

     int num = (int)Math.min((long) columnarBatch.capacity(), totalRowCount - rowsReturned);
     for (int i = 0; i < columnReaders.length; ++i) {
-      switch (columnReaders[i].descriptor.getType()) {
-        case INT32:
-          columnReaders[i].readIntBatch(num, columnarBatch.column(i));
-          break;
-        default:
-          throw new IOException("Unsupported type: " + columnReaders[i].descriptor.getType());
-      }
+      columnReaders[i].readBatch(num, columnarBatch.column(i));
     }
     rowsReturned += num;
     columnarBatch.setNumRows(num);
@@ -237,7 +233,8 @@ private void initializeInternal() throws IOException {

       // TODO: Be extremely cautious in what is supported. Expand this.
       if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
-          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE) {
+          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
+          originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
         throw new IOException("Unsupported type: " + t);
       }
       if (originalTypes[i] == OriginalType.DECIMAL &&
@@ -464,6 +461,11 @@ private final class ColumnReader {
      */
     private boolean useDictionary;

+    /**
+     * If useDictionary is true, the staging vector used to decode the ids.
+     */
+    private ColumnVector dictionaryIds;
+
     /**
      * Maximum definition level for this column.
      */
@@ -587,9 +589,8 @@ private boolean next() throws IOException {

     /**
      * Reads `total` values from this columnReader into column.
-     * TODO: implement the other encodings.
      */
-    private void readIntBatch(int total, ColumnVector column) throws IOException {
+    private void readBatch(int total, ColumnVector column) throws IOException {
       int rowId = 0;
       while (total > 0) {
         // Compute the number of values we want to read in this page.
@@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException {
           leftInPage = (int)(endOfPageValueCount - valuesRead);
         }
         int num = Math.min(total, leftInPage);
-        defColumn.readIntegers(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader)dataColumn, 0);
-
-        // Remap the values if it is dictionary encoded.
         if (useDictionary) {
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putInt(i, dictionary.decodeToInt(column.getInt(i)));
+          // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
+          if (dictionaryIds == null) {
+            dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
+          } else {
+            dictionaryIds.reset();
+            dictionaryIds.reserve(total);
+          }
+          // Read and decode dictionary ids.
+          readIntBatch(rowId, num, dictionaryIds);
+          decodeDictionaryIds(rowId, num, column);
+        } else {
+          switch (descriptor.getType()) {
+            case INT32:
+              readIntBatch(rowId, num, column);
+              break;
+            case INT64:
+              readLongBatch(rowId, num, column);
+              break;
+            case BINARY:
+              readBinaryBatch(rowId, num, column);
+              break;
+            default:
+              throw new IOException("Unsupported type: " + descriptor.getType());
           }
         }
+
         valuesRead += num;
         rowId += num;
         total -= num;
       }
     }

+    /**
+     * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
+     */
+    private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+      switch (descriptor.getType()) {
+        case INT32:
+          if (column.dataType() == DataTypes.IntegerType) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
+            }
+          } else if (column.dataType() == DataTypes.ByteType) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              column.putByte(i, (byte)dictionary.decodeToInt(dictionaryIds.getInt(i)));
+            }
+          } else {
+            throw new NotImplementedException("Unimplemented type: " + column.dataType());
+          }
+          break;
+
+        case INT64:
+          for (int i = rowId; i < rowId + num; ++i) {
+            column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+          }
+          break;
+
+        case BINARY:
+          // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
+          // need to do this better. We should probably add the dictionary data to the ColumnVector
+          // and reuse it across batches. This should mean adding a ByteArray would just update
+          // the length and offset.
+          for (int i = rowId; i < rowId + num; ++i) {
+            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+            column.putByteArray(i, v.getBytes());
+          }
+          break;
+
+        default:
+          throw new NotImplementedException("Unsupported type: " + descriptor.getType());
+      }
+
+      if (dictionaryIds.numNulls() > 0) {
+        // Copy the NULLs over.
+        // TODO: we can improve this by decoding the NULLs directly into column. This would
+        // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then
+        // just do the ID remapping as above.
+        for (int i = 0; i < num; ++i) {
+          if (dictionaryIds.getIsNull(rowId + i)) {
+            column.putNull(rowId + i);
+          }
+        }
+      }
+    }
+
+    /**
+     * For all the read*Batch functions, reads `num` values from this columnReader into column. It
+     * is guaranteed that num is smaller than the number of values left in the current page.
+     */
+
+    private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.dataType() == DataTypes.IntegerType) {
+        defColumn.readIntegers(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn, 0);
+      } else if (column.dataType() == DataTypes.ByteType) {
+        defColumn.readBytes(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+    private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.dataType() == DataTypes.LongType) {
+        defColumn.readLongs(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+    private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+      // This is where we implement support for the valid type conversions.
+      // TODO: implement remaining type conversions
+      if (column.isArray()) {
+        defColumn.readBinarys(
+            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+      } else {
+        throw new NotImplementedException("Unimplemented type: " + column.dataType());
+      }
+    }
+
+
     private void readPage() throws IOException {
       DataPage page = pageReader.readPage();
       // TODO: Why is this a visitor?
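To illustrate the two-phase decode the reader now performs for dictionary-encoded pages, here is a compact Scala sketch (the IntDictionary case class and the plain arrays stand in for the real Dictionary and ColumnVector types, which this snippet does not use): phase one reads the integer dictionary ids for the batch, phase two resolves each id through the dictionary into the output column.

object DictionaryDecodeSketch {
  final case class IntDictionary(values: Array[Int]) {
    def decodeToInt(id: Int): Int = values(id)
  }

  def decodeDictionaryIds(rowId: Int, num: Int,
      dictionaryIds: Array[Int],
      dictionary: IntDictionary,
      column: Array[Int]): Unit = {
    var i = rowId
    while (i < rowId + num) {
      column(i) = dictionary.decodeToInt(dictionaryIds(i))
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val dict = IntDictionary(Array(100, 200, 300))
    val ids = Array(2, 0, 1, 1)                          // phase 1: ids read from the data page
    val out = new Array[Int](ids.length)
    decodeDictionaryIds(0, ids.length, ids, dict, out)   // phase 2: resolve values via the dictionary
    println(out.mkString(", "))                          // 300, 100, 200, 200
  }
}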

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java

Lines changed: 43 additions & 2 deletions
@@ -18,10 +18,13 @@

 import java.io.IOException;

+import org.apache.spark.sql.Column;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 import org.apache.spark.unsafe.Platform;

+import org.apache.commons.lang.NotImplementedException;
 import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;

 /**
  * An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
@@ -52,15 +55,53 @@ public void skip(int n) {
   }

   @Override
-  public void readIntegers(int total, ColumnVector c, int rowId) {
+  public final void readIntegers(int total, ColumnVector c, int rowId) {
     c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
     offset += 4 * total;
   }

   @Override
-  public int readInteger() {
+  public final void readLongs(int total, ColumnVector c, int rowId) {
+    c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
+    offset += 8 * total;
+  }
+
+  @Override
+  public final void readBytes(int total, ColumnVector c, int rowId) {
+    for (int i = 0; i < total; i++) {
+      // Bytes are stored as a 4-byte little endian int. Just read the first byte.
+      // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
+      c.putInt(rowId + i, buffer[offset]);
+      offset += 4;
+    }
+  }
+
+  @Override
+  public final int readInteger() {
     int v = Platform.getInt(buffer, offset);
     offset += 4;
     return v;
   }
+
+  @Override
+  public final long readLong() {
+    long v = Platform.getLong(buffer, offset);
+    offset += 8;
+    return v;
+  }
+
+  @Override
+  public final byte readByte() {
+    return (byte)readInteger();
+  }
+
+  @Override
+  public final void readBinary(int total, ColumnVector v, int rowId) {
+    for (int i = 0; i < total; i++) {
+      int len = readInteger();
+      int start = offset;
+      offset += len;
+      v.putByteArray(rowId + i, buffer, start - Platform.BYTE_ARRAY_OFFSET, len);
+    }
+  }
 }
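The sketch below (plain Scala with java.nio, not the Platform-based reader above) shows the PLAIN layout those new methods decode: 32-bit and 64-bit values are little endian, and a "byte" occupies a full 4-byte slot of which only the low byte matters, which is why readBytes advances the offset by 4 per value.

import java.nio.{ByteBuffer, ByteOrder}

object PlainDecodeSketch {
  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
    buf.putInt(7).putLong(42L).putInt(5)   // an int, a long, and a "byte" stored as a 4-byte int
    buf.flip()

    val i = buf.getInt()                   // readInteger(): consume 4 bytes
    val l = buf.getLong()                  // readLong(): consume 8 bytes
    val b = buf.getInt().toByte            // readByte(): keep only the low byte of the 4-byte slot
    println(s"int=$i long=$l byte=$b")
  }
}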
