From dd9dbf0760f6caa7831ed8f2d04ee1fde2514c5e Mon Sep 17 00:00:00 2001 From: Nong Li Date: Mon, 4 Jan 2016 12:02:51 -0800 Subject: [PATCH] [SPARK-12589][SQL] Fix UnsafeRowParquetRecordReader to properly set the row length. The reader was previously not setting the row length meaning it was wrong if there were variable length columns. This problem does not manifest usually, since the value in the column is correct and projecting the row fixes the issue. --- .../sql/catalyst/expressions/UnsafeRow.java | 4 ++++ .../parquet/UnsafeRowParquetRecordReader.java | 9 +++++++ .../datasources/parquet/ParquetIOSuite.scala | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7492b88c471a4..1a351933a366c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -177,6 +177,10 @@ public void pointTo(byte[] buf, int sizeInBytes) { pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); } + public void setTotalSize(int sizeInBytes) { + this.sizeInBytes = sizeInBytes; + } + public void setNotNullAt(int i) { assertIndexIsValid(i); BitSetMethods.unset(baseObject, baseOffset, i); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index a6758bddfa7d0..198bfb6d67aee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -256,6 +256,15 @@ private boolean loadBatch() throws IOException { numBatched = num; batchIdx = 0; } + + // Update the total row lengths if the schema contained variable length. We did not maintain + // this as we populated the columns. + if (containsVarLenFields) { + for (int i = 0; i < numBatched; ++i) { + rows[i].setTotalSize(rowWriters[i].holder().totalSize()); + } + } + return true; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0c5d4887ed799..b0581e8b35510 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,6 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -618,6 +619,29 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { readResourceParquetFile("dec-in-fixed-len.parquet"), sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } + + test("SPARK-12589 copy() on rows returned from reader works for strings") { + withTempPath { dir => + val data = (1, "abc") ::(2, "helloabcde") :: Nil + data.toDF().write.parquet(dir.getCanonicalPath) + var hash1: Int = 0 + var hash2: Int = 0 + (false :: true :: Nil).foreach { v => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> v.toString) { + val df = sqlContext.read.parquet(dir.getCanonicalPath) + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) + if (!v) { + hash1 = unsafeRows(0).hashCode() + hash2 = unsafeRows(1).hashCode() + } else { + assert(hash1 == unsafeRows(0).hashCode()) + assert(hash2 == unsafeRows(1).hashCode()) + } + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)