diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index 3768f7a1824f1..5e02bc3b0de49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
 import java.util.{Map => JavaMap}
+
+import scala.util.hashing.MurmurHash3
 
 /**
@@ -35,6 +37,22 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
   override def toString: String = {
     s"keys: $keyArray, values: $valueArray"
   }
+
+  // Two maps are equal when their key arrays and value arrays are pairwise equal.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: ArrayBasedMapData => keyArray == other.keyArray && valueArray == other.valueArray
+      case _ => false
+    }
+  }
+
+  // Hash this class as a Product of two hashCodes. We don't know the DataType which prevents us
+  // from getting individual rows for hashing as a Map.
+  override def hashCode(): Int = {
+    val seed = MurmurHash3.productSeed
+    val keyHash = MurmurHash3.mix(seed, keyArray.hashCode())
+    val valueHash = MurmurHash3.mix(keyHash, valueArray.hashCode())
+    MurmurHash3.finalizeHash(valueHash, 2)
+  }
 }
 
 object ArrayBasedMapData {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala
index 94e8824cd18cc..fd34c9f1e657c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.catalyst.util
 import org.apache.spark.sql.types.DataType
 
 /**
- * This is an internal data representation for map type in Spark SQL. This should not implement
- * `equals` and `hashCode` because the type cannot be used as join keys, grouping keys, or
- * in equality tests. See SPARK-9415 and PR#13847 for the discussions.
+ * This is an internal data representation for map type in Spark SQL. This type cannot be used as
+ * join keys, grouping keys, or in equality tests. See SPARK-9415 and PR#13847 for the discussions.
  */
 abstract class MapData extends Serializable {
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
index 6e07cd5d6415d..8176984e99366 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
@@ -142,4 +142,34 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {
       Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
     }
   }
+
+  test("SPARK-40315: simple equals() and hashCode() semantics") {
+    val dataToAdd: Map[Int, Int] = Map(0 -> -7, 1 -> 3, 10 -> 4, 20 -> 5)
+    val builder1 = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+    val builder2 = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+    val builder3 = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+    dataToAdd.foreach { case (key, value) =>
+      builder1.put(key, value)
+      builder2.put(key, value)
+      // Replace the value by something slightly different in builder3 for one of the keys.
+      if (key == 20) {
+        builder3.put(key, value - 1)
+      } else {
+        builder3.put(key, value)
+      }
+    }
+    val arrayBasedMapData1 = builder1.build()
+    val arrayBasedMapData2 = builder2.build()
+    val arrayBasedMapData3 = builder3.build()
+
+    // We expect two objects to be equal and to have the same hashCode if they have the same
+    // elements.
+    assert(arrayBasedMapData1.equals(arrayBasedMapData2))
+    assert(arrayBasedMapData1.hashCode() == arrayBasedMapData2.hashCode())
+
+    // If two objects have different elements, we expect them not to be equal and their hashCode
+    // to be different.
+    assert(!arrayBasedMapData1.equals(arrayBasedMapData3))
+    assert(arrayBasedMapData1.hashCode() != arrayBasedMapData3.hashCode())
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
index f921f06537080..c6fb2ea14a021 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala
@@ -39,8 +39,6 @@ class ComplexDataSuite extends SparkFunSuite {
     val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
     val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
     val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
-    assert(testArrayMap1 !== testArrayMap3)
-    assert(testArrayMap2 !== testArrayMap4)
 
     // UnsafeMapData
     val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))