From 2fd8e360cac3edc64e58effabdd071a35e3235a9 Mon Sep 17 00:00:00 2001 From: mcdull-zhang Date: Sun, 13 Mar 2022 20:25:47 +0800 Subject: [PATCH] serialize numKeys out --- .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 4 +++- .../spark/sql/execution/joins/HashedRelationSuite.scala | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 698e7ed6fc57e..253f16e39d352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -207,7 +207,7 @@ private[execution] class ValueRowWithKeyIndex { * A HashedRelation for UnsafeRow, which is backed BytesToBytesMap. * * It's serialized in the following format: - * [number of keys] + * [number of keys] [number of fields] * [size of key] [size of value] [key bytes] [bytes for value] */ private[joins] class UnsafeHashedRelation( @@ -364,6 +364,7 @@ private[joins] class UnsafeHashedRelation( writeInt: (Int) => Unit, writeLong: (Long) => Unit, writeBuffer: (Array[Byte], Int, Int) => Unit) : Unit = { + writeInt(numKeys) writeInt(numFields) // TODO: move these into BytesToBytesMap writeLong(binaryMap.numKeys()) @@ -397,6 +398,7 @@ private[joins] class UnsafeHashedRelation( readInt: () => Int, readLong: () => Long, readBuffer: (Array[Byte], Int, Int) => Unit): Unit = { + numKeys = readInt() numFields = readInt() resultRow = new UnsafeRow(numFields) val nKeys = readLong() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2462fe31a9b66..6c87178f267c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -93,6 +93,9 @@ class HashedRelationSuite extends SharedSparkSession { assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)).toArray === data2) + // SPARK-38542: UnsafeHashedRelation should serialize numKeys out + assert(hashed2.keys().map(_.copy()).forall(_.numFields == 1)) + val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2) hashed2.writeExternal(out2)