diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
index 446ea1463309f..380717d2e0e9b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
@@ -106,10 +106,19 @@ Map<String, String> loadPartitionMetadata(InternalRow ident)
       throws UnsupportedOperationException;
 
   /**
-   * List the identifiers of all partitions that contains the ident in a table.
+   * List the identifiers of all partitions that have the given identifier prefix in a table.
    *
    * @param ident a prefix of the partition identifier
    * @return an array of Identifiers for the partitions
    */
   InternalRow[] listPartitionIdentifiers(InternalRow ident);
+
+  /**
+   * List the identifiers of all partitions that match the given identifier by names.
+   *
+   * @param names the names of the partition values in the identifier
+   * @param ident the values of the partition identifier
+   * @return an array of Identifiers for the partitions
+   */
+  InternalRow[] listPartitionByNames(String[] names, InternalRow ident);
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
index 23987e909aa70..ba762a58b1e52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.types.StructType
@@ -96,4 +97,25 @@ class InMemoryPartitionTable(
   override protected def addPartitionKey(key: Seq[Any]): Unit = {
     memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
   }
+
+  override def listPartitionByNames(
+      names: Array[String],
+      ident: InternalRow): Array[InternalRow] = {
+    assert(names.length == ident.numFields,
+      s"Number of partition names (${names.length}) must be equal to " +
+      s"the number of partition values (${ident.numFields}).")
+    val schema = partitionSchema
+    assert(names.forall(fieldName => schema.fieldNames.contains(fieldName)),
+      s"Some partition names ${names.mkString("[", ", ", "]")} don't belong to " +
+      s"the partition schema '${schema.sql}'.")
+    val indexes = names.map(schema.fieldIndex)
+    val dataTypes = names.map(schema(_).dataType)
+    val currentRow = new GenericInternalRow(new Array[Any](names.length))
+    memoryTablePartitions.keySet().asScala.filter { key =>
+      for (i <- 0 until names.length) {
+        currentRow.values(i) = key.get(indexes(i), dataTypes(i))
+      }
+      currentRow == ident
+    }.toArray
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
index e8e28e3422f27..caf7e91612563 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -140,4 +140,45 @@ class SupportsPartitionManagementSuite extends SparkFunSuite {
     partTable.dropPartition(partIdent1)
     assert(partTable.listPartitionIdentifiers(InternalRow.empty).isEmpty)
   }
+
+  test("listPartitionByNames") {
+    val partCatalog = new InMemoryPartitionTableCatalog
+    partCatalog.initialize("test", CaseInsensitiveStringMap.empty())
+    val table = partCatalog.createTable(
+      ident,
+      new StructType()
+        .add("col0", IntegerType)
+        .add("part0", IntegerType)
+        .add("part1", StringType),
+      Array(LogicalExpressions.identity(ref("part0")), LogicalExpressions.identity(ref("part1"))),
+      util.Collections.emptyMap[String, String])
+    val partTable = table.asInstanceOf[InMemoryPartitionTable]
+
+    Seq(
+      InternalRow(0, "abc"),
+      InternalRow(0, "def"),
+      InternalRow(1, "abc")).foreach { partIdent =>
+      partTable.createPartition(partIdent, new util.HashMap[String, String]())
+    }
+
+    Seq(
+      (Array("part0", "part1"), InternalRow(0, "abc")) -> Set(InternalRow(0, "abc")),
+      (Array("part0"), InternalRow(0)) -> Set(InternalRow(0, "abc"), InternalRow(0, "def")),
+      (Array("part1"), InternalRow("abc")) -> Set(InternalRow(0, "abc"), InternalRow(1, "abc")),
+      (Array.empty[String], InternalRow.empty) ->
+        Set(InternalRow(0, "abc"), InternalRow(0, "def"), InternalRow(1, "abc")),
+      (Array("part0", "part1"), InternalRow(3, "xyz")) -> Set(),
+      (Array("part1"), InternalRow(3.14f)) -> Set()
+    ).foreach { case ((names, idents), expected) =>
+      assert(partTable.listPartitionByNames(names, idents).toSet === expected)
+    }
+    // Check invalid parameters
+    Seq(
+      (Array("part0", "part1"), InternalRow(0)),
+      (Array("col0", "part1"), InternalRow(0, 1)),
+      (Array("wrong"), InternalRow("invalid"))
+    ).foreach { case (names, idents) =>
+      intercept[AssertionError](partTable.listPartitionByNames(names, idents))
+    }
+  }
 }
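
For reference, a minimal usage sketch (not part of the patch) contrasting the existing prefix-based API with the new name-based one. It reuses the partTable fixture built in the test above, i.e. a table partitioned by part0: Int and part1: String:

// Sketch only, assuming the partTable fixture from the test above.
import org.apache.spark.sql.catalyst.InternalRow

// listPartitionIdentifiers matches a prefix of the partition identifier,
// so only leading partition fields can be constrained: here, part0 = 0.
val byPrefix: Array[InternalRow] = partTable.listPartitionIdentifiers(InternalRow(0))

// listPartitionByNames pins each value to a schema field by name, so any
// subset of partition fields can be constrained: here, only part1 = 'abc'.
val byName: Array[InternalRow] =
  partTable.listPartitionByNames(Array("part1"), InternalRow("abc"))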