diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionStatistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionStatistics.java
new file mode 100644
index 0000000000000..7416f90ea1f55
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionStatistics.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.util.OptionalLong;
+
+/**
+ * A mix-in for input partitions whose records are clustered on the same set of partition keys
+ * (provided via {@link SupportsReportPartitioning}, see below). Data sources can opt in to
+ * implement this interface for the partitions they report to Spark, and Spark uses the reported
+ * statistics to decide whether partition grouping should be applied.
+ *
+ * @see org.apache.spark.sql.connector.read.SupportsReportPartitioning
+ * @since 4.0.0
+ */
+public interface HasPartitionStatistics extends InputPartition {
+
+  /**
+   * Returns the size in bytes of the partition statistics associated with this partition.
+   */
+  OptionalLong sizeInBytes();
+
+  /**
+   * Returns the number of rows in the partition statistics associated with this partition.
+   */
+  OptionalLong numRows();
+
+  /**
+   * Returns the count of files in the partition statistics associated with this partition.
+   */
+  OptionalLong filesCount();
+}
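Outside the patch itself, here is a rough sketch of how a custom connector's input partition might adopt the new mix-in. The ExampleFilePartition class, its fields, and the way the statistics are obtained are hypothetical and only illustrate the contract defined above; a partition reported via SupportsReportPartitioning would typically also implement HasPartitionKey.

    import java.util.OptionalLong

    import org.apache.spark.sql.connector.read.{HasPartitionStatistics, InputPartition}

    // Hypothetical split descriptor for a file-based connector.
    case class ExampleFilePartition(
        files: Seq[String],
        bytes: Long,
        rowCount: Option[Long]) extends InputPartition with HasPartitionStatistics {

      // Size is usually known from file metadata, so it can always be reported.
      override def sizeInBytes(): OptionalLong = OptionalLong.of(bytes)

      // Row counts may be unknown (e.g. no footer statistics); report empty rather than guessing.
      override def numRows(): OptionalLong =
        rowCount.map(r => OptionalLong.of(r)).getOrElse(OptionalLong.empty())

      override def filesCount(): OptionalLong = OptionalLong.of(files.size.toLong)
    }
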
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index c1967f558c171..505a5a6169204 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -605,7 +605,7 @@ object InMemoryBaseTable {
 }
 
 class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage
-  with InputPartition with HasPartitionKey with Serializable {
+  with InputPartition with HasPartitionKey with HasPartitionStatistics with Serializable {
   val rows = new mutable.ArrayBuffer[InternalRow]()
   val deletes = new mutable.ArrayBuffer[Int]()
 
@@ -617,6 +617,9 @@ class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage
   def keyString(): String = key.toArray.mkString("/")
 
   override def partitionKey(): InternalRow = PartitionInternalRow(key.toArray)
+  override def sizeInBytes(): OptionalLong = OptionalLong.of(100L)
+  override def numRows(): OptionalLong = OptionalLong.of(rows.size)
+  override def filesCount(): OptionalLong = OptionalLong.of(100L)
 
   def clear(): Unit = rows.clear()
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 93f199dfd5854..d89c0a2525fd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2928,6 +2928,67 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
+  test("Check HasPartitionStatistics from InMemoryPartitionTable") {
+    val t = "testpart.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id string) USING foo PARTITIONED BY (key int)")
+      val table = catalog("testpart").asTableCatalog
+        .loadTable(Identifier.of(Array(), "tbl"))
+        .asInstanceOf[InMemoryPartitionTable]
+
+      var partSizes = table.data.map(_.sizeInBytes().getAsLong)
+      var partRowCounts = table.data.map(_.numRows().getAsLong)
+      var partFiles = table.data.map(_.filesCount().getAsLong)
+      assert(partSizes.length == 0)
+      assert(partRowCounts.length == 0)
+      assert(partFiles.length == 0)
+
+      sql(s"INSERT INTO $t VALUES ('a', 1), ('b', 2), ('c', 3)")
+      partSizes = table.data.map(_.sizeInBytes().getAsLong)
+      assert(partSizes.length == 3)
+      assert(partSizes.toSet == Set(100, 100, 100))
+      partRowCounts = table.data.map(_.numRows().getAsLong)
+      assert(partRowCounts.length == 3)
+      assert(partRowCounts.toSet == Set(1, 1, 1))
+      partFiles = table.data.map(_.filesCount().getAsLong)
+      assert(partFiles.length == 3)
+      assert(partFiles.toSet == Set(100, 100, 100))
+
+      sql(s"ALTER TABLE $t DROP PARTITION (key=3)")
+      partSizes = table.data.map(_.sizeInBytes().getAsLong)
+      assert(partSizes.length == 2)
+      assert(partSizes.toSet == Set(100, 100))
+      partRowCounts = table.data.map(_.numRows().getAsLong)
+      assert(partRowCounts.length == 2)
+      assert(partRowCounts.toSet == Set(1, 1))
+      partFiles = table.data.map(_.filesCount().getAsLong)
+      assert(partFiles.length == 2)
+      assert(partFiles.toSet == Set(100, 100))
+
+      sql(s"ALTER TABLE $t ADD PARTITION (key=4)")
+      partSizes = table.data.map(_.sizeInBytes().getAsLong)
+      assert(partSizes.length == 3)
+      assert(partSizes.toSet == Set(100, 100, 100))
+      partRowCounts = table.data.map(_.numRows().getAsLong)
+      assert(partRowCounts.length == 3)
+      assert(partRowCounts.toSet == Set(1, 1, 0))
+      partFiles = table.data.map(_.filesCount().getAsLong)
+      assert(partFiles.length == 3)
+      assert(partFiles.toSet == Set(100, 100, 100))
+
+      sql(s"INSERT INTO $t VALUES ('c', 3), ('e', 5)")
+      partSizes = table.data.map(_.sizeInBytes().getAsLong)
+      assert(partSizes.length == 5)
+      assert(partSizes.toSet == Set(100, 100, 100, 100, 100))
+      partRowCounts = table.data.map(_.numRows().getAsLong)
+      assert(partRowCounts.length == 5)
+      assert(partRowCounts.toSet == Set(1, 1, 0, 1, 1))
+      partFiles = table.data.map(_.filesCount().getAsLong)
+      assert(partFiles.length == 5)
+      assert(partFiles.toSet == Set(100, 100, 100, 100, 100))
+    }
+  }
+
   test("time travel") {
     sql("use testcat")
     // The testing in-memory table simply append the version/timestamp to the table name when