diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 1a8f444042c2..20026dba8ff9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
 import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
@@ -37,6 +38,7 @@ import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityC
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
@@ -46,6 +48,8 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()
 
+  private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf)
+
   override def shortName(): String = "statestore"
 
   override def getTable(
@@ -54,7 +58,17 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    new StateTable(session, schema, sourceOptions, stateConf)
+    // Read the operator metadata once to see if we can find the information for prefix scan
+    // encoder used in session window aggregation queries.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf)
+      .stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == sourceOptions.operatorId &&
+        entry.stateStoreName == sourceOptions.storeName
+    }
+
+    new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index bbfe3a3f373e..f6d3b3a06b23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.types.StructType
@@ -33,11 +33,12 @@ import org.apache.spark.util.SerializableConfiguration
 class StatePartitionReaderFactory(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
-    schema: StructType) extends PartitionReaderFactory {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     new StatePartitionReader(storeConf, hadoopConf,
-      partition.asInstanceOf[StateStoreInputPartition], schema)
+      partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
   }
 }
 
@@ -49,7 +50,9 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] with Logging {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
+  extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -58,13 +61,6 @@ class StatePartitionReader(
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == partition.sourceOptions.operatorId &&
-        entry.stateStoreName == partition.sourceOptions.storeName
-    }
     val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
       logWarning("Metadata for state store not found, possible cause is this checkpoint " +
         "is created by older version of spark. If the query has session window aggregation, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 0d69bf708e94..3c0370b025d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.StructType
@@ -35,8 +36,10 @@ class StateScanBuilder(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends ScanBuilder {
-  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf)
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder {
+  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
+    stateStoreMetadata)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -50,7 +53,8 @@ class StateScan(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends Scan with Batch {
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   private val hadoopConfBroadcast = session.sparkContext.broadcast(
@@ -62,7 +66,8 @@ class StateScan(
     val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
     val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
       override def accept(path: Path): Boolean = {
-        fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+        fs.getFileStatus(path).isDirectory &&
+          Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
       }
     })
 
@@ -105,7 +110,8 @@ class StateScan(
         hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
     case JoinSideValues.none =>
-      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema)
+      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
+        stateStoreMetadata)
   }
 
   override def toBatch: Batch = this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 824968e709ba..151350a4a044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -35,7 +36,8 @@ class StateTable(
     session: SparkSession,
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateConf: StateStoreConf)
+    stateConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
   extends Table with SupportsRead with SupportsMetadataColumns {
   import StateTable._
 
@@ -64,7 +66,7 @@ class StateTable(
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf)
+    new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata)
 
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
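
For context, a minimal sketch of what the hoisted metadata lookup feeds downstream, mirroring the branch that remains in StatePartitionReader after this change. It assumes StateMetadataTableEntry exposes a numColsPrefixKey field as in the metadata table reader; the object and method names below are hypothetical and only illustrate the flow, not the patch itself.

import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry

object PrefixScanResolutionSketch {
  // With this patch the operator metadata is read once on the driver and the pre-filtered
  // array is handed to every partition reader, so each reader only inspects the array.
  def resolveNumColsPrefixKey(stateStoreMetadata: Array[StateMetadataTableEntry]): Int = {
    if (stateStoreMetadata.isEmpty) {
      // Metadata missing, e.g. a checkpoint written by an older Spark version:
      // fall back to zero prefix-key columns (no prefix scan support).
      0
    } else {
      require(stateStoreMetadata.length == 1,
        s"Expected one metadata entry, found ${stateStoreMetadata.length}")
      stateStoreMetadata.head.numColsPrefixKey
    }
  }
}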
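
A usage sketch of the reader path the metadata is threaded through, assuming the option names defined in StateSourceOptions ("operatorId", "storeName") and a placeholder checkpoint directory; only the "statestore" short name comes directly from the patch above.

import org.apache.spark.sql.SparkSession

object StateStoreReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("statestore-read-sketch")
      .master("local[*]")
      .getOrCreate()

    // Placeholder checkpoint location of a streaming query, e.g. one running a session
    // window aggregation so that the prefix scan encoder path is exercised.
    val checkpointDir = "/tmp/checkpoints/sessionized-counts"

    // With this change, getTable loads the operator metadata once on the driver before
    // any per-partition readers are created.
    val stateDf = spark.read
      .format("statestore")
      .option("operatorId", "0")
      .option("storeName", "default")
      .load(checkpointDir)

    // The table exposes "key" and "value" struct columns for the selected state store.
    stateDf.printSchema()
    stateDf.show(truncate = false)

    spark.stop()
  }
}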