File: SQLConf.scala
@@ -1294,6 +1294,14 @@ object SQLConf {
.createWithDefault(
"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")

val STATE_SCHEMA_CHECK_ENABLED =
buildConf("spark.sql.streaming.stateStore.stateSchemaCheck")
.doc("When true, Spark will validate the state schema against schema on existing state and " +
"fail query if it's incompatible.")
.version("3.1.0")
.booleanConf
Review comment (Member):
.version("3.1.0")?

Reply (Contributor Author):
Yeah, the versioning didn't even exist before, but now it's needed. Thanks for the pointer.

.createWithDefault(true)

val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
.internal()
@@ -3064,6 +3072,8 @@ class SQLConf extends Serializable with Logging {

def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)

def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
Review comment (Contributor):
Nit: there was a discussion about these helper vals, and the agreement was to create them only when they are used in multiple places; in all other cases, inline them.

Reply (Contributor Author, @HeartSaVioR, Mar 26, 2019):
For StateStoreConf it's not the same whether the helper is available or not, because of how default values are read. If we remove it and access the conf via its key, we have to handle the default value manually (the confs captured in StateStoreConf don't provide a default value when the key isn't set).

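A minimal, illustrative sketch of the point in the reply above: `getConf` falls back to the entry's registered default, while a raw key lookup (like the config map StateStoreConf captures) does not. The placement and object name are hypothetical; SQLConf and its config entries are Spark-internal, so this would have to live inside Spark's own source tree (e.g., a test).

```scala
// Hypothetical placement: SQLConf and its config entries are Spark-internal APIs.
package org.apache.spark.sql.internal

object DefaultValueIllustration {
  def main(args: Array[String]): Unit = {
    val conf = new SQLConf
    // getConf falls back to the entry's default (true here) even though nothing was set explicitly.
    val viaHelper: Boolean = conf.getConf(SQLConf.STATE_SCHEMA_CHECK_ENABLED)
    // A raw key/value map (like the configs StateStoreConf captures) has no notion of defaults:
    // an unset key simply yields None, and the caller would have to supply the default itself.
    val rawStateStoreConfs: Map[String, String] =
      conf.getAllConfs.filter { case (k, _) => k.startsWith("spark.sql.streaming.stateStore.") }
    val viaRawKey: Option[String] = rawStateStoreConfs.get(SQLConf.STATE_SCHEMA_CHECK_ENABLED.key)
    println(s"viaHelper = $viaHelper, viaRawKey = $viaRawKey") // viaHelper = true, viaRawKey = None
  }
}
```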

def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
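For reference, a minimal sketch of how an end user would toggle the new flag; it assumes a local SparkSession, and the key string comes from the conf definition above. The checker's error message points users at this key when they need to bypass the validation.

```scala
import org.apache.spark.sql.SparkSession

object ToggleStateSchemaCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("state-schema-check-demo")
      .getOrCreate()

    // Enabled by default; a registered entry's default is returned even when nothing was set.
    println(spark.conf.get("spark.sql.streaming.stateStore.stateSchemaCheck")) // true
    spark.conf.set("spark.sql.streaming.stateStore.stateSchemaCheck", "false")
    println(spark.conf.get("spark.sql.streaming.stateStore.stateSchemaCheck")) // false

    spark.stop()
  }
}
```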
File: DataType.scala
@@ -307,21 +307,49 @@ object DataType {
* of `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
equalsIgnoreCompatibleNullability(from, to, ignoreName = false)
}

/**
* Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType, and
* also ignoring field names; struct fields are compared by position.
*
* Compatible nullability is defined as follows:
* - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
* if and only if `to.containsNull` is true, or both of `from.containsNull` and
* `to.containsNull` are false.
* - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
* if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
* `to.valueContainsNull` are false.
* - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
* if and only if, for every pair of fields, `to.nullable` is true, or both
* of `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreNameAndCompatibleNullability(
from: DataType,
to: DataType): Boolean = {
equalsIgnoreCompatibleNullability(from, to, ignoreName = true)
}

private def equalsIgnoreCompatibleNullability(
from: DataType,
to: DataType,
ignoreName: Boolean = false): Boolean = {
(from, to) match {
case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement, ignoreName)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
(tn || !fn) &&
equalsIgnoreCompatibleNullability(fromKey, toKey) &&
equalsIgnoreCompatibleNullability(fromValue, toValue)
equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) &&
equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (fromField, toField) =>
fromField.name == toField.name &&
(ignoreName || fromField.name == toField.name) &&
(toField.nullable || !fromField.nullable) &&
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName)
}

case (fromDataType, toDataType) => fromDataType == toDataType
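To illustrate what the new name-insensitive variant accepts, here is a small sketch. Both helpers are `private[sql]`, so assume it lives under an `org.apache.spark.sql` package (for example, in a test); the object name is made up for the example.

```scala
package org.apache.spark.sql.types

object SchemaCompatibilitySketch {
  def main(args: Array[String]): Unit = {
    val stored   = new StructType().add("count", LongType, nullable = false).add("name", StringType)
    val provided = new StructType().add("cnt", LongType, nullable = false).add("n", StringType)

    // Same field positions, types, and compatible nullability, but different field names:
    // the name-insensitive check accepts the rename...
    println(DataType.equalsIgnoreNameAndCompatibleNullability(stored, provided)) // true
    // ...while the original name-sensitive check rejects it.
    println(DataType.equalsIgnoreCompatibleNullability(stored, provided))        // false
  }
}
```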
File: HDFSMetadataLog.scala
@@ -267,36 +267,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
}
}

/**
* Parse the log version from the given `text` -- will throw exception when the parsed version
* exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1",
* "v123xyz" etc.)
*/
private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int = {
if (text.length > 0 && text(0) == 'v') {
val version =
try {
text.substring(1, text.length).toInt
} catch {
case _: NumberFormatException =>
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
if (version > 0) {
if (version > maxSupportedVersion) {
throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
} else {
return version
}
}
}

// reaching here means we failed to read the correct log version
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int =
MetadataVersionUtil.validateVersion(text, maxSupportedVersion)
}

object HDFSMetadataLog {
File: MetadataVersionUtil.scala (new file)
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

object MetadataVersionUtil {
/**
* Parse the log version from the given `text` -- will throw exception when the parsed version
* exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1",
* "v123xyz" etc.)
*/
def validateVersion(text: String, maxSupportedVersion: Int): Int = {
if (text.length > 0 && text(0) == 'v') {
val version =
try {
text.substring(1, text.length).toInt
} catch {
case _: NumberFormatException =>
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
if (version > 0) {
if (version > maxSupportedVersion) {
throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
} else {
return version
}
}
}

// reaching here means we failed to read the correct log version
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
}
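A quick sketch of how the extracted helper behaves, matching the code it was lifted from; it assumes the caller sits inside Spark's source tree (the object is not a public API), and the wrapper object name is made up for the example.

```scala
import org.apache.spark.sql.execution.streaming.MetadataVersionUtil

object VersionParsingSketch {
  def main(args: Array[String]): Unit = {
    // Well-formed and supported: returns the numeric version.
    println(MetadataVersionUtil.validateVersion("v1", maxSupportedVersion = 1)) // 1

    // A newer version than this build supports fails fast.
    try MetadataVersionUtil.validateVersion("v2", maxSupportedVersion = 1)
    catch { case e: IllegalStateException => println(e.getMessage) } // UnsupportedLogVersion ...

    // Malformed input ("xyz", "v", "v-1", "v123xyz", ...) also fails fast.
    try MetadataVersionUtil.validateVersion("xyz", maxSupportedVersion = 1)
    catch { case e: IllegalStateException => println(e.getMessage) } // Log file was malformed ...
  }
}
```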
File: StateSchemaCompatibilityChecker.scala (new file)
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.state

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}

case class StateSchemaNotCompatible(message: String) extends Exception(message)

class StateSchemaCompatibilityChecker(
providerId: StateStoreProviderId,
hadoopConf: Configuration) extends Logging {

private val storeCpLocation = providerId.storeId.storeCheckpointLocation()
private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf)
private val schemaFileLocation = schemaFile(storeCpLocation)

fm.mkdirs(schemaFileLocation.getParent)

def check(keySchema: StructType, valueSchema: StructType): Unit = {
if (fm.exists(schemaFileLocation)) {
logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.")
val (storedKeySchema, storedValueSchema) = readSchemaFile()
if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) {
// schema is exactly same
} else if (!schemasCompatible(storedKeySchema, keySchema) ||
!schemasCompatible(storedValueSchema, valueSchema)) {
val errorMsg = "Provided schema doesn't match the schema of the existing state! " +
"Please note that Spark allows differences in field names: check the number of fields " +
"and the data type of each field.\n" +
s"- Provided key schema: $keySchema\n" +
s"- Provided value schema: $valueSchema\n" +
s"- Existing key schema: $storedKeySchema\n" +
s"- Existing value schema: $storedValueSchema\n" +
s"If you want to force running query without schema validation, please set " +
s"${SQLConf.STATE_SCHEMA_CHECK_ENABLED.key} to false.\n" +
"Please note running query with incompatible schema could cause indeterministic" +
" behavior."
logError(errorMsg)
throw StateSchemaNotCompatible(errorMsg)
} else {
logInfo("Detected schema change which is compatible. Allowing to put rows.")
}
} else {
// schema doesn't exist, create one now
logDebug(s"Schema file for provider $providerId doesn't exist. Creating one.")
createSchemaFile(keySchema, valueSchema)
}
}

private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean =
DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema)

private def readSchemaFile(): (StructType, StructType) = {
val inStream = fm.open(schemaFileLocation)
try {
val versionStr = inStream.readUTF()
// Currently we only support version 1, so the version validation and the parsing logic
// can be kept simple.
val version = MetadataVersionUtil.validateVersion(versionStr,
StateSchemaCompatibilityChecker.VERSION)
require(version == 1)

val keySchemaStr = inStream.readUTF()
val valueSchemaStr = inStream.readUTF()

(StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr))
} catch {
case e: Throwable =>
logError(s"Fail to read schema file from $schemaFileLocation", e)
throw e
} finally {
inStream.close()
}
}

private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = {
val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false)
try {
outStream.writeUTF(s"v${StateSchemaCompatibilityChecker.VERSION}")
outStream.writeUTF(keySchema.json)
outStream.writeUTF(valueSchema.json)
outStream.close()
} catch {
case e: Throwable =>
logError(s"Fail to write schema file to $schemaFileLocation", e)
outStream.cancel()
throw e
}
}

private def schemaFile(storeCpLocation: Path): Path =
new Path(new Path(storeCpLocation, "_metadata"), "schema")
}

object StateSchemaCompatibilityChecker {
val VERSION = 1
}
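For context, the schema file written above is just three modified-UTF-8 strings in sequence: a version marker ("v1"), the key schema as JSON, and the value schema as JSON. A minimal sketch of reading one back follows; it assumes a locally accessible checkpoint directory and a path layout derived from storeCheckpointLocation (an assumption), whereas the real code goes through CheckpointFileManager.

```scala
import java.io.{DataInputStream, FileInputStream}

import org.apache.spark.sql.types.{DataType, StructType}

object ReadStateSchemaFileSketch {
  // e.g. <checkpointDir>/state/<operatorId>/<partitionId>/_metadata/schema (assumed layout)
  def read(path: String): (StructType, StructType) = {
    val in = new DataInputStream(new FileInputStream(path))
    try {
      val version = in.readUTF()
      require(version == "v1", s"unexpected schema file version: $version")
      // The key and value schemas are stored as the JSON produced by StructType#json.
      val keySchema = DataType.fromJson(in.readUTF()).asInstanceOf[StructType]
      val valueSchema = DataType.fromJson(in.readUTF()).asInstanceOf[StructType]
      (keySchema, valueSchema)
    } finally {
      in.close()
    }
  }
}
```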
File: StateStore.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.util.Try
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
@@ -280,14 +281,14 @@ object StateStoreProvider {
* Return a instance of the required provider, initialized with the given configurations.
*/
def createAndInit(
stateStoreId: StateStoreId,
providerId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
indexOrdinal: Option[Int], // for sorting the data
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStoreProvider = {
val provider = create(storeConf.providerClass)
provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider.init(providerId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
provider
}

@@ -386,10 +387,14 @@ object StateStore extends Logging {

val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
val PARTITION_ID_TO_CHECK_SCHEMA = 0

@GuardedBy("loadedProviders")
private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()

@GuardedBy("loadedProviders")
private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]]()

/**
* Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
* will be called when an exception happens.
@@ -467,10 +472,29 @@
hadoopConf: Configuration): StateStoreProvider = {
loadedProviders.synchronized {
startMaintenanceIfNeeded()

if (storeProviderId.storeId.partitionId == PARTITION_ID_TO_CHECK_SCHEMA) {
val result = schemaValidated.getOrElseUpdate(storeProviderId, {
val checker = new StateSchemaCompatibilityChecker(storeProviderId, hadoopConf)
// regardless of configuration, we check compatibility to at least write schema file
// if necessary
val ret = Try(checker.check(keySchema, valueSchema)).toEither.fold(Some(_), _ => None)
if (storeConf.stateSchemaCheckEnabled) {
ret
} else {
None
}
})

if (result.isDefined) {
throw result.get
}
}

val provider = loadedProviders.getOrElseUpdate(
storeProviderId,
StateStoreProvider.createAndInit(
storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
storeProviderId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
)
reportActiveStoreInstance(storeProviderId)
provider
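The schema-check block above memoizes one validation result per provider and keeps only the failure; the Try-to-Option conversion it relies on is just the idiom below, shown as a standalone sketch (the object and method names are made up for the example).

```scala
import scala.util.Try

object FirstFailureSketch {
  // Run a validation for its side effects and keep only the failure, so it can be
  // cached (as schemaValidated does above) and re-thrown on subsequent lookups.
  def failureOf(validate: => Unit): Option[Throwable] =
    Try(validate).toEither.fold(Some(_), _ => None)

  def main(args: Array[String]): Unit = {
    println(failureOf(()))                                      // None
    println(failureOf(sys.error("schemas are not compatible"))) // Some(java.lang.RuntimeException: ...)
  }
}
```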
@@ -482,6 +506,12 @@
loadedProviders.remove(storeProviderId).foreach(_.close())
}

/** Unload all state store providers: unit test purpose */
private[sql] def unloadAll(): Unit = loadedProviders.synchronized {
Review comment (Member):
Ditto: if we eagerly check the stateSchemaCheckEnabled config, this test-specific function could also be removed, and we could use the config to control the behavior in the tests?

Reply (Contributor Author):
Same here as well.

loadedProviders.keySet.foreach { key => unload(key) }
loadedProviders.clear()
}

/** Whether a state store provider is loaded or not */
def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized {
loadedProviders.contains(storeProviderId)
File: StateStoreConf.scala
@@ -55,6 +55,9 @@ class StateStoreConf(
/** The compression codec used to compress delta and snapshot files. */
val compressionCodec: String = sqlConf.stateStoreCompressionCodec

/** whether to validate state schema during query run. */
val stateSchemaCheckEnabled = sqlConf.isStateSchemaCheckEnabled

/**
* Additional configurations related to state store. This will capture all configs in
* SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific