2121import java .nio .ByteBuffer ;
2222import java .util .List ;
2323
24+ import org .apache .commons .lang .NotImplementedException ;
2425import org .apache .hadoop .mapreduce .InputSplit ;
2526import org .apache .hadoop .mapreduce .TaskAttemptContext ;
2627import org .apache .parquet .Preconditions ;
4142import org .apache .spark .sql .catalyst .expressions .codegen .UnsafeRowWriter ;
4243import org .apache .spark .sql .execution .vectorized .ColumnVector ;
4344import org .apache .spark .sql .execution .vectorized .ColumnarBatch ;
45+ import org .apache .spark .sql .types .DataTypes ;
4446import org .apache .spark .sql .types .Decimal ;
4547import org .apache .spark .unsafe .Platform ;
4648import org .apache .spark .unsafe .types .UTF8String ;
@@ -207,13 +209,7 @@ public boolean nextBatch() throws IOException {
207209
208210 int num = (int )Math .min ((long ) columnarBatch .capacity (), totalRowCount - rowsReturned );
209211 for (int i = 0 ; i < columnReaders .length ; ++i ) {
210- switch (columnReaders [i ].descriptor .getType ()) {
211- case INT32 :
212- columnReaders [i ].readIntBatch (num , columnarBatch .column (i ));
213- break ;
214- default :
215- throw new IOException ("Unsupported type: " + columnReaders [i ].descriptor .getType ());
216- }
212+ columnReaders [i ].readBatch (num , columnarBatch .column (i ));
217213 }
218214 rowsReturned += num ;
219215 columnarBatch .setNumRows (num );
@@ -237,7 +233,8 @@ private void initializeInternal() throws IOException {
237233
238234 // TODO: Be extremely cautious in what is supported. Expand this.
239235 if (originalTypes [i ] != null && originalTypes [i ] != OriginalType .DECIMAL &&
240- originalTypes [i ] != OriginalType .UTF8 && originalTypes [i ] != OriginalType .DATE ) {
236+ originalTypes [i ] != OriginalType .UTF8 && originalTypes [i ] != OriginalType .DATE &&
237+ originalTypes [i ] != OriginalType .INT_8 && originalTypes [i ] != OriginalType .INT_16 ) {
241238 throw new IOException ("Unsupported type: " + t );
242239 }
243240 if (originalTypes [i ] == OriginalType .DECIMAL &&
@@ -464,6 +461,11 @@ private final class ColumnReader {
464461 */
465462 private boolean useDictionary ;
466463
464+ /**
465+ * If useDictionary is true, the staging vector used to decode the ids.
466+ */
467+ private ColumnVector dictionaryIds ;
468+
467469 /**
468470 * Maximum definition level for this column.
469471 */
@@ -587,9 +589,8 @@ private boolean next() throws IOException {
587589
588590 /**
589591 * Reads `total` values from this columnReader into column.
590- * TODO: implement the other encodings.
591592 */
592- private void readIntBatch (int total , ColumnVector column ) throws IOException {
593+ private void readBatch (int total , ColumnVector column ) throws IOException {
593594 int rowId = 0 ;
594595 while (total > 0 ) {
595596 // Compute the number of values we want to read in this page.
@@ -599,21 +600,134 @@ private void readIntBatch(int total, ColumnVector column) throws IOException {
599600 leftInPage = (int )(endOfPageValueCount - valuesRead );
600601 }
601602 int num = Math .min (total , leftInPage );
602- defColumn .readIntegers (
603- num , column , rowId , maxDefLevel , (VectorizedValuesReader )dataColumn , 0 );
604-
605- // Remap the values if it is dictionary encoded.
606603 if (useDictionary ) {
607- for (int i = rowId ; i < rowId + num ; ++i ) {
608- column .putInt (i , dictionary .decodeToInt (column .getInt (i )));
604+ // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
605+ if (dictionaryIds == null ) {
606+ dictionaryIds = ColumnVector .allocate (total , DataTypes .IntegerType , MemoryMode .ON_HEAP );
607+ } else {
608+ dictionaryIds .reset ();
609+ dictionaryIds .reserve (total );
610+ }
611+ // Read and decode dictionary ids.
612+ readIntBatch (rowId , num , dictionaryIds );
613+ decodeDictionaryIds (rowId , num , column );
614+ } else {
615+ switch (descriptor .getType ()) {
616+ case INT32 :
617+ readIntBatch (rowId , num , column );
618+ break ;
619+ case INT64 :
620+ readLongBatch (rowId , num , column );
621+ break ;
622+ case BINARY :
623+ readBinaryBatch (rowId , num , column );
624+ break ;
625+ default :
626+ throw new IOException ("Unsupported type: " + descriptor .getType ());
609627 }
610628 }
629+
611630 valuesRead += num ;
612631 rowId += num ;
613632 total -= num ;
614633 }
615634 }
616635
636+ /**
637+ * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
638+ */
639+ private void decodeDictionaryIds (int rowId , int num , ColumnVector column ) {
640+ switch (descriptor .getType ()) {
641+ case INT32 :
642+ if (column .dataType () == DataTypes .IntegerType ) {
643+ for (int i = rowId ; i < rowId + num ; ++i ) {
644+ column .putInt (i , dictionary .decodeToInt (dictionaryIds .getInt (i )));
645+ }
646+ } else if (column .dataType () == DataTypes .ByteType ) {
647+ for (int i = rowId ; i < rowId + num ; ++i ) {
648+ column .putByte (i , (byte )dictionary .decodeToInt (dictionaryIds .getInt (i )));
649+ }
650+ } else {
651+ throw new NotImplementedException ("Unimplemented type: " + column .dataType ());
652+ }
653+ break ;
654+
655+ case INT64 :
656+ for (int i = rowId ; i < rowId + num ; ++i ) {
657+ column .putLong (i , dictionary .decodeToLong (dictionaryIds .getInt (i )));
658+ }
659+ break ;
660+
661+ case BINARY :
662+ // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
663+ // need to do this better. We should probably add the dictionary data to the ColumnVector
664+ // and reuse it across batches. This should mean adding a ByteArray would just update
665+ // the length and offset.
666+ for (int i = rowId ; i < rowId + num ; ++i ) {
667+ Binary v = dictionary .decodeToBinary (dictionaryIds .getInt (i ));
668+ column .putByteArray (i , v .getBytes ());
669+ }
670+ break ;
671+
672+ default :
673+ throw new NotImplementedException ("Unsupported type: " + descriptor .getType ());
674+ }
675+
676+ if (dictionaryIds .numNulls () > 0 ) {
677+ // Copy the NULLs over.
678+ // TODO: we can improve this by decoding the NULLs directly into column. This would
679+ // mean we decode the int ids into `dictionaryIds` and the NULLs into `column` and then
680+ // just do the ID remapping as above.
681+ for (int i = 0 ; i < num ; ++i ) {
682+ if (dictionaryIds .getIsNull (rowId + i )) {
683+ column .putNull (rowId + i );
684+ }
685+ }
686+ }
687+ }
688+
689+ /**
690+ * For all the read*Batch functions, reads `num` values from this columnReader into column. It
691+ * is guaranteed that num is smaller than the number of values left in the current page.
692+ */
693+
694+ private void readIntBatch (int rowId , int num , ColumnVector column ) throws IOException {
695+ // This is where we implement support for the valid type conversions.
696+ // TODO: implement remaining type conversions
697+ if (column .dataType () == DataTypes .IntegerType ) {
698+ defColumn .readIntegers (
699+ num , column , rowId , maxDefLevel , (VectorizedValuesReader ) dataColumn , 0 );
700+ } else if (column .dataType () == DataTypes .ByteType ) {
701+ defColumn .readBytes (
702+ num , column , rowId , maxDefLevel , (VectorizedValuesReader ) dataColumn );
703+ } else {
704+ throw new NotImplementedException ("Unimplemented type: " + column .dataType ());
705+ }
706+ }
707+
708+ private void readLongBatch (int rowId , int num , ColumnVector column ) throws IOException {
709+ // This is where we implement support for the valid type conversions.
710+ // TODO: implement remaining type conversions
711+ if (column .dataType () == DataTypes .LongType ) {
712+ defColumn .readLongs (
713+ num , column , rowId , maxDefLevel , (VectorizedValuesReader ) dataColumn );
714+ } else {
715+ throw new NotImplementedException ("Unimplemented type: " + column .dataType ());
716+ }
717+ }
718+
719+ private void readBinaryBatch (int rowId , int num , ColumnVector column ) throws IOException {
720+ // This is where we implement support for the valid type conversions.
721+ // TODO: implement remaining type conversions
722+ if (column .isArray ()) {
723+ defColumn .readBinarys (
724+ num , column , rowId , maxDefLevel , (VectorizedValuesReader ) dataColumn );
725+ } else {
726+ throw new NotImplementedException ("Unimplemented type: " + column .dataType ());
727+ }
728+ }
729+
730+
617731 private void readPage () throws IOException {
618732 DataPage page = pageReader .readPage ();
619733 // TODO: Why is this a visitor?
0 commit comments