diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
index 1c96bdf3afa20..23987e909aa70 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
@@ -92,4 +92,8 @@ class InMemoryPartitionTable(

   override def partitionExists(ident: InternalRow): Boolean =
     memoryTablePartitions.containsKey(ident)
+
+  override protected def addPartitionKey(key: Seq[Any]): Unit = {
+    memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 3b47271a114e2..c93053abc550a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -160,12 +160,15 @@ class InMemoryTable(
     }
   }

+  protected def addPartitionKey(key: Seq[Any]): Unit = {}
+
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
     data.foreach(_.rows.foreach { row =>
       val key = getKey(row)
       dataMap += dataMap.get(key)
         .map(key -> _.withRow(row))
         .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row))
+      addPartitionKey(key)
     })
     this
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 0057415ff6e1d..89f97fe5be6a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._

 import org.apache.spark.SparkException
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connector.catalog._
@@ -35,6 +36,7 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

 class DataSourceV2SQLSuite
@@ -2460,6 +2462,25 @@ class DataSourceV2SQLSuite
     }
   }

+  test("SPARK-33505: insert into partitioned table") {
+    val t = "testpart.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"""
+        |CREATE TABLE $t (id bigint, city string, data string)
+        |USING foo
+        |PARTITIONED BY (id, city)""".stripMargin)
+      val partTable = catalog("testpart").asTableCatalog
+        .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable]
+      val expectedPartitionIdent = InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY")))
+      assert(!partTable.partitionExists(expectedPartitionIdent))
+      sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'")
+      assert(partTable.partitionExists(expectedPartitionIdent))
+      // Inserting into the existing partition must not fail
+      sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'def'")
+      assert(partTable.partitionExists(expectedPartitionIdent))
+    }
+  }
+
   private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")
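For context, the patch wires a template-method hook into the V2 test tables: `InMemoryTable.withData` now calls a no-op `addPartitionKey` for every buffered row, and `InMemoryPartitionTable` overrides it to register the row's partition, so after a plain `INSERT INTO ... PARTITION(...)` the catalog's `partitionExists` reports the partition. The following is a minimal self-contained sketch of that pattern; the class and member names are simplified stand-ins, not the actual Spark test classes.

```scala
import java.util.concurrent.ConcurrentHashMap

class BaseTable {
  // No-op by default; partition-aware subclasses override this hook.
  protected def addPartitionKey(key: Seq[Any]): Unit = {}

  // keyOf stands in for InMemoryTable's getKey, which derives the
  // partition key from a row according to the table's partitioning.
  def withData(rows: Seq[Seq[Any]], keyOf: Seq[Any] => Seq[Any]): this.type = {
    rows.foreach { row =>
      // ... buffer the row under its key, as InMemoryTable.withData does ...
      addPartitionKey(keyOf(row)) // let subclasses observe every written key
    }
    this
  }
}

class PartitionedTable extends BaseTable {
  private val partitions = new ConcurrentHashMap[Seq[Any], Map[String, String]]()

  // Record the partition so partitionExists reflects it after a write.
  // put on an existing key is a harmless overwrite, which is why the
  // second INSERT into the same partition in the test cannot fail.
  override protected def addPartitionKey(key: Seq[Any]): Unit =
    partitions.put(key, Map.empty)

  def partitionExists(key: Seq[Any]): Boolean = partitions.containsKey(key)
}
```

Keeping the base hook a no-op means non-partitioned tables pay nothing for the change, while the partitioned subclass sees every write path (including `INSERT`), which is exactly what the new `DataSourceV2SQLSuite` test verifies.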