diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 07cd41b06de21..27f3625255d89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1294,6 +1294,14 @@ object SQLConf {
     .createWithDefault(
       "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider")

+  val STATE_SCHEMA_CHECK_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.stateSchemaCheck")
+      .doc("When true, Spark will validate the state schema against the schema of the " +
+        "existing state and fail the query if they are incompatible.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
       .internal()
@@ -3064,6 +3072,8 @@ class SQLConf extends Serializable with Logging {

   def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS)

+  def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED)
+
   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)

   def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index e4ee6eb377a4d..9e820f0796a96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -307,21 +307,49 @@ object DataType {
    * of `fromField.nullable` and `toField.nullable` are false.
    */
   private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
+    equalsIgnoreCompatibleNullability(from, to, ignoreName = false)
+  }
+
+  /**
+   * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType, and
+   * also ignoring field names; fields are compared by position.
+   *
+   * Compatible nullability is defined as follows:
+   *   - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
+   *     if and only if `to.containsNull` is true, or both of `from.containsNull` and
+   *     `to.containsNull` are false.
+   *   - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
+   *     if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+   *     `to.valueContainsNull` are false.
+   *   - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
+   *     if and only if for every pair of fields, `to.nullable` is true, or both
+   *     of `fromField.nullable` and `toField.nullable` are false.
+ */ + private[sql] def equalsIgnoreNameAndCompatibleNullability( + from: DataType, + to: DataType): Boolean = { + equalsIgnoreCompatibleNullability(from, to, ignoreName = true) + } + + private def equalsIgnoreCompatibleNullability( + from: DataType, + to: DataType, + ignoreName: Boolean = false): Boolean = { (from, to) match { case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement, ignoreName) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) + equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) && + equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && + (ignoreName || fromField.name == toField.name) && (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 893639a86c88c..b87a5b49eb6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -267,36 +267,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - /** - * Parse the log version from the given `text` -- will throw exception when the parsed version - * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", - * "v123xyz" etc.) - */ - private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int = { - if (text.length > 0 && text(0) == 'v') { - val version = - try { - text.substring(1, text.length).toInt - } catch { - case _: NumberFormatException => - throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + - s"version from $text.") - } - if (version > 0) { - if (version > maxSupportedVersion) { - throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + - s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + - s"by a newer version of Spark and cannot be read by this version. 
Please upgrade.")
-        } else {
-          return version
-        }
-      }
-    }
-
-    // reaching here means we failed to read the correct log version
-    throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
-      s"version from $text.")
-  }
+  private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int =
+    MetadataVersionUtil.validateVersion(text, maxSupportedVersion)
 }

 object HDFSMetadataLog {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala
new file mode 100644
index 0000000000000..548f2aa5d5c5b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+object MetadataVersionUtil {
+  /**
+   * Parse the log version from the given `text` -- will throw an exception when the parsed
+   * version exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v",
+   * "v-1", "v123xyz" etc.)
+   */
+  def validateVersion(text: String, maxSupportedVersion: Int): Int = {
+    if (text.length > 0 && text(0) == 'v') {
+      val version =
+        try {
+          text.substring(1, text.length).toInt
+        } catch {
+          case _: NumberFormatException =>
+            throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
+              s"version from $text.")
+        }
+      if (version > 0) {
+        if (version > maxSupportedVersion) {
+          throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " +
+            s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
+            s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
+        } else {
+          return version
+        }
+      }
+    }
+
+    // reaching here means we failed to read the correct log version
+    throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
+      s"version from $text.")
+  }
+}
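A quick illustration of the contract `validateVersion` enforces (behavior taken from the code above; the call sites are hypothetical):

```scala
import org.apache.spark.sql.execution.streaming.MetadataVersionUtil

// A well-formed version string "v<positive int>" parses to its numeric value.
assert(MetadataVersionUtil.validateVersion("v1", maxSupportedVersion = 1) == 1)

// A version above maxSupportedVersion throws IllegalStateException: the file was
// written by a newer Spark and cannot be read by this one.
// MetadataVersionUtil.validateVersion("v2", maxSupportedVersion = 1)

// Malformed inputs ("xyz", "v", "v-1", "v123xyz") also throw IllegalStateException.
// MetadataVersionUtil.validateVersion("v-1", maxSupportedVersion = 1)
```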
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
new file mode 100644
index 0000000000000..4ac12c089c0d3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+
+case class StateSchemaNotCompatible(message: String) extends Exception(message)
+
+class StateSchemaCompatibilityChecker(
+    providerId: StateStoreProviderId,
+    hadoopConf: Configuration) extends Logging {
+
+  private val storeCpLocation = providerId.storeId.storeCheckpointLocation()
+  private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf)
+  private val schemaFileLocation = schemaFile(storeCpLocation)
+
+  fm.mkdirs(schemaFileLocation.getParent)
+
+  def check(keySchema: StructType, valueSchema: StructType): Unit = {
+    if (fm.exists(schemaFileLocation)) {
+      logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.")
+      val (storedKeySchema, storedValueSchema) = readSchemaFile()
+      if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) {
+        // schemas are exactly the same
+      } else if (!schemasCompatible(storedKeySchema, keySchema) ||
+        !schemasCompatible(storedValueSchema, valueSchema)) {
+        val errorMsg = "Provided schema doesn't match the schema of the existing state! " +
+          "Note that Spark allows field names to differ: check the number of fields " +
+          "and the data type of each field.\n" +
+          s"- Provided key schema: $keySchema\n" +
+          s"- Provided value schema: $valueSchema\n" +
+          s"- Existing key schema: $storedKeySchema\n" +
+          s"- Existing value schema: $storedValueSchema\n" +
+          s"If you want to force the query to run without schema validation, set " +
+          s"${SQLConf.STATE_SCHEMA_CHECK_ENABLED.key} to false.\n" +
+          "Note that running the query with an incompatible schema could cause " +
+          "non-deterministic behavior."
+        logError(errorMsg)
+        throw StateSchemaNotCompatible(errorMsg)
+      } else {
+        logInfo("Detected schema change which is compatible. Allowing the query to run.")
+      }
+    } else {
+      // schema file doesn't exist, create one now
+      logDebug(s"Schema file for provider $providerId doesn't exist. Creating one.")
+      createSchemaFile(keySchema, valueSchema)
+    }
+  }
+
+  private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean =
+    DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema)
+
+  private def readSchemaFile(): (StructType, StructType) = {
+    val inStream = fm.open(schemaFileLocation)
+    try {
+      val versionStr = inStream.readUTF()
+      // We only support version 1 for now, so the version validation and the parse logic
+      // can stay simple.
+      val version = MetadataVersionUtil.validateVersion(versionStr,
+        StateSchemaCompatibilityChecker.VERSION)
+      require(version == 1)
+
+      val keySchemaStr = inStream.readUTF()
+      val valueSchemaStr = inStream.readUTF()
+
+      (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr))
+    } catch {
+      case e: Throwable =>
+        logError(s"Failed to read schema file from $schemaFileLocation", e)
+        throw e
+    } finally {
+      inStream.close()
+    }
+  }
+
+  private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = {
+    val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false)
+    try {
+      outStream.writeUTF(s"v${StateSchemaCompatibilityChecker.VERSION}")
+      outStream.writeUTF(keySchema.json)
+      outStream.writeUTF(valueSchema.json)
+      outStream.close()
+    } catch {
+      case e: Throwable =>
+        logError(s"Failed to write schema file to $schemaFileLocation", e)
+        outStream.cancel()
+        throw e
+    }
+  }
+
+  private def schemaFile(storeCpLocation: Path): Path =
+    new Path(new Path(storeCpLocation, "_metadata"), "schema")
+}
+
+object StateSchemaCompatibilityChecker {
+  val VERSION = 1
+}
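The checker delegates to `DataType.equalsIgnoreNameAndCompatibleNullability`, added above. Since that method is `private[sql]`, the sketch below only illustrates its semantics on hand-built schemas (the field names and values are illustrative):

```scala
import org.apache.spark.sql.types._

val stored = new StructType()
  .add("id", IntegerType, nullable = true)
  .add("count", LongType, nullable = true)

// Renaming fields is compatible: comparison is positional and names are ignored.
val renamed = new StructType()
  .add("user_id", IntegerType, nullable = true)
  .add("cnt", LongType, nullable = true)

// Changing a field's type is incompatible.
val retyped = new StructType()
  .add("id", IntegerType, nullable = true)
  .add("count", DoubleType, nullable = true)

// From within the org.apache.spark.sql package:
// DataType.equalsIgnoreNameAndCompatibleNullability(stored, renamed) // true
// DataType.equalsIgnoreNameAndCompatibleNullability(stored, retyped) // false
// Nullability may only relax: a nullable stored field cannot become non-nullable.
```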
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 05bcee7b05c6f..ab67c19783ff7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy

 import scala.collection.mutable
+import scala.util.Try
 import scala.util.control.NonFatal

 import org.apache.hadoop.conf.Configuration
@@ -280,14 +281,14 @@ object StateStoreProvider {
    * Return a instance of the required provider, initialized with the given configurations.
    */
   def createAndInit(
-      stateStoreId: StateStoreId,
+      providerId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
       indexOrdinal: Option[Int], // for sorting the data
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     val provider = create(storeConf.providerClass)
-    provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+    provider.init(providerId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
     provider
   }
@@ -386,10 +387,14 @@ object StateStore extends Logging {

   val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
+  val PARTITION_ID_TO_CHECK_SCHEMA = 0

   @GuardedBy("loadedProviders")
   private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()

+  @GuardedBy("loadedProviders")
+  private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]]()
+
   /**
    * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
    * will be called when an exception happens.
@@ -467,10 +472,29 @@ object StateStore extends Logging {
       hadoopConf: Configuration): StateStoreProvider = {
     loadedProviders.synchronized {
       startMaintenanceIfNeeded()
+
+      if (storeProviderId.storeId.partitionId == PARTITION_ID_TO_CHECK_SCHEMA) {
+        val result = schemaValidated.getOrElseUpdate(storeProviderId, {
+          val checker = new StateSchemaCompatibilityChecker(storeProviderId, hadoopConf)
+          // regardless of configuration, we check compatibility to at least write the schema file
+          // if necessary
+          val ret = Try(checker.check(keySchema, valueSchema)).toEither.fold(Some(_), _ => None)
+          if (storeConf.stateSchemaCheckEnabled) {
+            ret
+          } else {
+            None
+          }
+        })
+
+        if (result.isDefined) {
+          throw result.get
+        }
+      }
+
       val provider = loadedProviders.getOrElseUpdate(
         storeProviderId,
         StateStoreProvider.createAndInit(
-          storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+          storeProviderId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
       )
       reportActiveStoreInstance(storeProviderId)
       provider
@@ -482,6 +506,12 @@ object StateStore extends Logging {
     loadedProviders.remove(storeProviderId).foreach(_.close())
   }

+  /** Unload all state store providers: for unit tests only */
+  private[sql] def unloadAll(): Unit = loadedProviders.synchronized {
+    loadedProviders.keySet.foreach { key => unload(key) }
+    loadedProviders.clear()
+  }
+
   /** Whether a state store provider is loaded or not */
   def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized {
     loadedProviders.contains(storeProviderId)
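The `Try(...).toEither.fold(Some(_), _ => None)` dance in `StateStore.get` is easy to misread, so here is the same pattern in isolation (a minimal sketch with hypothetical names): the check runs exactly once per provider on partition 0, the schema file is written even when validation is disabled, and the captured failure is only rethrown when the flag is on.

```scala
import scala.collection.mutable
import scala.util.Try

// Cache of the one-time validation outcome per provider: None = passed or disabled,
// Some(t) = the failure to rethrow on every subsequent access.
val schemaValidated = mutable.HashMap[String, Option[Throwable]]()

def validateOnce(providerId: String, checkEnabled: Boolean)(check: => Unit): Unit = {
  val result = schemaValidated.getOrElseUpdate(providerId, {
    // Run the check regardless of the flag (so the schema file gets created),
    // keeping the exception on failure and discarding the unit result on success.
    val ret = Try(check).toEither.fold(Some(_), _ => None)
    if (checkEnabled) ret else None
  })
  result.foreach(t => throw t)
}
```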
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 11043bc81ae3f..23cb3be32c85a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -55,6 +55,9 @@ class StateStoreConf(
   /** The compression codec used to compress delta and snapshot files. */
   val compressionCodec: String = sqlConf.stateStoreCompressionCodec

+  /** Whether to validate the state schema while the query runs. */
+  val stateSchemaCheckEnabled = sqlConf.isStateSchemaCheckEnabled
+
   /**
    * Additional configurations related to state store. This will capture all configs in
    * SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
new file mode 100644
index 0000000000000..4eb7603b316aa
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper.newDir +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { + + private val hadoopConf: Configuration = new Configuration() + private val opId = Random.nextInt(100000) + private val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + + private val structSchema = new StructType() + .add(StructField("nested1", IntegerType, nullable = true)) + .add(StructField("nested2", StringType, nullable = true)) + + private val keySchema = new StructType() + .add(StructField("key1", IntegerType, nullable = true)) + .add(StructField("key2", StringType, nullable = true)) + .add(StructField("key3", structSchema, nullable = true)) + + private val valueSchema = new StructType() + .add(StructField("value1", IntegerType, nullable = true)) + .add(StructField("value2", StringType, nullable = true)) + .add(StructField("value3", structSchema, nullable = true)) + + test("adding field to key should fail") { + val fieldAddedKeySchema = keySchema.add(StructField("newKey", IntegerType)) + verifyException(keySchema, valueSchema, fieldAddedKeySchema, valueSchema) + } + + test("adding field to value should fail") { + val fieldAddedValueSchema = valueSchema.add(StructField("newValue", IntegerType)) + verifyException(keySchema, valueSchema, keySchema, fieldAddedValueSchema) + } + + test("adding nested field in key should fail") { + val fieldAddedNestedSchema = structSchema.add(StructField("newNested", IntegerType)) + val newKeySchema = applyNewSchemaToNestedFieldInKey(fieldAddedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } + + test("adding nested field in value should fail") { + val fieldAddedNestedSchema = structSchema.add(StructField("newNested", IntegerType)) + val newValueSchema = applyNewSchemaToNestedFieldInValue(fieldAddedNestedSchema) + verifyException(keySchema, valueSchema, keySchema, newValueSchema) + } + + test("removing field from key should fail") { + val fieldRemovedKeySchema = StructType(keySchema.dropRight(1)) + verifyException(keySchema, valueSchema, fieldRemovedKeySchema, valueSchema) + } + + test("removing field from value should fail") { + val fieldRemovedValueSchema = StructType(valueSchema.drop(1)) + verifyException(keySchema, valueSchema, keySchema, fieldRemovedValueSchema) + } + + test("removing nested field from key should fail") { + val fieldRemovedNestedSchema = StructType(structSchema.dropRight(1)) + val newKeySchema = applyNewSchemaToNestedFieldInKey(fieldRemovedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } + + test("removing nested field from value should fail") { + val fieldRemovedNestedSchema = StructType(structSchema.drop(1)) + val newValueSchema = applyNewSchemaToNestedFieldInValue(fieldRemovedNestedSchema) + verifyException(keySchema, 
valueSchema, keySchema, newValueSchema)
+  }
+
+  test("changing the type of field in key should fail") {
+    val typeChangedKeySchema = StructType(keySchema.map(_.copy(dataType = TimestampType)))
+    verifyException(keySchema, valueSchema, typeChangedKeySchema, valueSchema)
+  }
+
+  test("changing the type of field in value should fail") {
+    val typeChangedValueSchema = StructType(valueSchema.map(_.copy(dataType = TimestampType)))
+    verifyException(keySchema, valueSchema, keySchema, typeChangedValueSchema)
+  }
+
+  test("changing the type of nested field in key should fail") {
+    val typeChangedNestedSchema = StructType(structSchema.map(_.copy(dataType = TimestampType)))
+    val newKeySchema = applyNewSchemaToNestedFieldInKey(typeChangedNestedSchema)
+    verifyException(keySchema, valueSchema, newKeySchema, valueSchema)
+  }
+
+  test("changing the type of nested field in value should fail") {
+    val typeChangedNestedSchema = StructType(structSchema.map(_.copy(dataType = TimestampType)))
+    val newValueSchema = applyNewSchemaToNestedFieldInValue(typeChangedNestedSchema)
+    verifyException(keySchema, valueSchema, keySchema, newValueSchema)
+  }
+
+  test("changing the nullability of nullable to non-nullable in key should fail") {
+    val nonNullChangedKeySchema = StructType(keySchema.map(_.copy(nullable = false)))
+    verifyException(keySchema, valueSchema, nonNullChangedKeySchema, valueSchema)
+  }
+
+  test("changing the nullability of nullable to non-nullable in value should fail") {
+    val nonNullChangedValueSchema = StructType(valueSchema.map(_.copy(nullable = false)))
+    verifyException(keySchema, valueSchema, keySchema, nonNullChangedValueSchema)
+  }
+
+  test("changing the nullability of nullable to non-nullable in nested field in key " +
+    "should fail") {
+    val nullChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false)))
+    val newKeySchema = applyNewSchemaToNestedFieldInKey(nullChangedNestedSchema)
+    verifyException(keySchema, valueSchema, newKeySchema, valueSchema)
+  }
+
+  test("changing the nullability of nullable to non-nullable in nested field in value " +
+    "should fail") {
+    val nullChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false)))
+    val newValueSchema = applyNewSchemaToNestedFieldInValue(nullChangedNestedSchema)
+    verifyException(keySchema, valueSchema, keySchema, newValueSchema)
+  }
+
+  test("changing the name of field in key should be allowed") {
+    val newName: StructField => StructField = f => f.copy(name = f.name + "_new")
+    val fieldNameChangedKeySchema = StructType(keySchema.map(newName))
+    verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, valueSchema)
+  }
+
+  test("changing the name of field in value should be allowed") {
+    val newName: StructField => StructField = f => f.copy(name = f.name + "_new")
+    val fieldNameChangedValueSchema = StructType(valueSchema.map(newName))
+    verifySuccess(keySchema, valueSchema, keySchema, fieldNameChangedValueSchema)
+  }
+
+  test("changing the name of nested field in key should be allowed") {
+    val newName: StructField => StructField = f => f.copy(name = f.name + "_new")
+    val newNestedFieldsSchema = StructType(structSchema.map(newName))
+    val fieldNameChangedKeySchema = applyNewSchemaToNestedFieldInKey(newNestedFieldsSchema)
+    verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, valueSchema)
+  }
+
+  test("changing the name of nested field in value should be allowed") {
+    val newName: StructField => StructField = f => f.copy(name = f.name + "_new")
+    val newNestedFieldsSchema =
StructType(structSchema.map(newName))
+    val fieldNameChangedValueSchema = applyNewSchemaToNestedFieldInValue(newNestedFieldsSchema)
+    verifySuccess(keySchema, valueSchema, keySchema, fieldNameChangedValueSchema)
+  }
+
+  private def applyNewSchemaToNestedFieldInKey(newNestedSchema: StructType): StructType = {
+    applyNewSchemaToNestedField(keySchema, newNestedSchema, "key3")
+  }
+
+  private def applyNewSchemaToNestedFieldInValue(newNestedSchema: StructType): StructType = {
+    applyNewSchemaToNestedField(valueSchema, newNestedSchema, "value3")
+  }
+
+  private def applyNewSchemaToNestedField(
+      originSchema: StructType,
+      newNestedSchema: StructType,
+      fieldName: String): StructType = {
+    val newFields = originSchema.map { field =>
+      if (field.name == fieldName) {
+        field.copy(dataType = newNestedSchema)
+      } else {
+        field
+      }
+    }
+    StructType(newFields)
+  }
+
+  private def runSchemaChecker(
+      dir: String,
+      queryId: UUID,
+      newKeySchema: StructType,
+      newValueSchema: StructType): Unit = {
+    // Spark doesn't support online state schema changes, so the schema only needs to be
+    // checked once per run of the JVM
+    val providerId = StateStoreProviderId(
+      StateStoreId(dir, opId, partitionId), queryId)
+
+    new StateSchemaCompatibilityChecker(providerId, hadoopConf)
+      .check(newKeySchema, newValueSchema)
+  }
+
+  private def verifyException(
+      oldKeySchema: StructType,
+      oldValueSchema: StructType,
+      newKeySchema: StructType,
+      newValueSchema: StructType): Unit = {
+    val dir = newDir()
+    val queryId = UUID.randomUUID()
+    runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema)
+
+    val e = intercept[StateSchemaNotCompatible] {
+      runSchemaChecker(dir, queryId, newKeySchema, newValueSchema)
+    }
+
+    assert(e.getMessage.contains("Provided schema doesn't match the schema of the existing state"))
+    assert(e.getMessage.contains(newKeySchema.json))
+    assert(e.getMessage.contains(newValueSchema.json))
+    assert(e.getMessage.contains(oldKeySchema.json))
+    assert(e.getMessage.contains(oldValueSchema.json))
+  }
+
+  private def verifySuccess(
+      oldKeySchema: StructType,
+      oldValueSchema: StructType,
+      newKeySchema: StructType,
+      newValueSchema: StructType): Unit = {
+    val dir = newDir()
+    val queryId = UUID.randomUUID()
+    runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema)
+    runSchemaChecker(dir, queryId, newKeySchema, newValueSchema)
+  }
+}
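For the restart test in the next file, it helps to see roughly what the state value schemas look like before and after the query change. This is a simplification: the exact buffer layout depends on the aggregates and the state format version, so treat the schemas below as illustrative.

```scala
import org.apache.spark.sql.types._

// Roughly, the state row for the first query's aggregate (sum/avg/max over an Int column):
val oldValueSchema = new StructType()
  .add("sum_value", LongType)
  .add("avg_value", DoubleType)
  .add("max_value", IntegerType)

// After the restart, collect_list replaces max, so the third field becomes an array:
val newValueSchema = new StructType()
  .add("sum_value", LongType)
  .add("avg_value", DoubleType)
  .add("values", ArrayType(IntegerType))

// The rename alone would be tolerated (comparison is positional), but IntegerType vs
// ArrayType(IntegerType) is an incompatible type change, so the restarted query fails
// with StateSchemaNotCompatible instead of corrupting or misreading the stored state.
```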
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 0524e29662014..491b0d8b2c26c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.streaming
 import java.io.File
 import java.util.{Locale, TimeZone}

+import scala.annotation.tailrec
+
 import org.apache.commons.io.FileUtils
 import org.scalatest.Assertions

@@ -33,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemorySink
-import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager
+import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
@@ -753,6 +755,89 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     )
   }

+  testQuietlyWithAllStateVersions("changing schema of state when restarting query",
+    (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) {
+    withTempDir { tempDir =>
+      val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir)
+
+      // Without the state schema verification phase, the modified query would throw an NPE
+      // with a stack trace that end users could not easily understand.
+      testStream(aggregated, Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 21),
+        ExpectFailure[SparkException] { e =>
+          val stateSchemaExc = findStateSchemaNotCompatible(e)
+          assert(stateSchemaExc.isDefined)
+          val msg = stateSchemaExc.get.getMessage
+          assert(msg.contains("Provided schema doesn't match the schema of the existing state"))
+          // other verifications are covered in StateSchemaCompatibilityCheckerSuite
+        }
+      )
+    }
+  }
+
+  testQuietlyWithAllStateVersions("changing schema of state when restarting query -" +
+    " schema check off",
+    (SQLConf.STATE_SCHEMA_CHECK_ENABLED.key, "false"),
+    (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) {
+    withTempDir { tempDir =>
+      val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir)
+
+      testStream(aggregated, Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 21),
+        ExpectFailure[SparkException] { e =>
+          val stateSchemaExc = findStateSchemaNotCompatible(e)
+          // the query may hit other errors at runtime, but it must not check the schema
+          // in any way
+          assert(stateSchemaExc.isEmpty)
+        }
+      )
+    }
+  }
+
+  private def prepareTestForChangingSchemaOfState(
+      tempDir: File): (MemoryStream[Int], DataFrame) = {
+    val inputData = MemoryStream[Int]
+    val aggregated = inputData.toDF()
+      .selectExpr("value % 10 AS id", "value")
+      .groupBy($"id")
+      .agg(
+        sum("value").as("sum_value"),
+        avg("value").as("avg_value"),
+        max("value").as("max_value"))
+
+    testStream(aggregated, Update())(
+      StartStream(checkpointLocation = tempDir.getAbsolutePath),
+      AddData(inputData, 1, 11),
+      CheckLastBatch((1L, 12L, 6.0, 11)),
+      StopStream
+    )
+
+    StateStore.unloadAll()
+
+    val inputData2 = MemoryStream[Int]
+    val aggregated2 = inputData2.toDF()
+      .selectExpr("value % 10 AS id", "value")
+      .groupBy($"id")
+      .agg(
+        sum("value").as("sum_value"),
+        avg("value").as("avg_value"),
+        collect_list("value").as("values"))
+
+    inputData2.addData(1, 11)
+
+    (inputData2, aggregated2)
+  }
+
+  @tailrec
+  private def findStateSchemaNotCompatible(exc: Throwable): Option[StateSchemaNotCompatible] = {
+    exc match {
+      case e1: StateSchemaNotCompatible => Some(e1)
+      case e1 if e1.getCause != null => findStateSchemaNotCompatible(e1.getCause)
+      case _ => None
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`.
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala index 33f6b02acb6dd..1032d6c5b6ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import java.io.File +import scala.annotation.tailrec + import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{InvalidUnsafeRowException, StateSchemaNotCompatible} import org.apache.spark.sql.functions._ import org.apache.spark.util.Utils @@ -239,11 +242,19 @@ class StreamingStateStoreFormatCompatibilitySuite extends StreamTest { CheckAnswer(Row(0, 20, Seq(0, 2, 4, 6, 8)), Row(1, 25, Seq(1, 3, 5, 7, 9))) */ AddData(inputData, 10 to 19: _*), - ExpectFailure[SparkException](e => { - // Check the exception message to make sure the state store format changing. - assert(e.getCause.getCause.getMessage.contains( - "The streaming query failed by state format invalidation.")) - }) + ExpectFailure[SparkException] { e => + assert(findStateSchemaException(e)) + } ) } + + @tailrec + private def findStateSchemaException(exc: Throwable): Boolean = { + exc match { + case _: StateSchemaNotCompatible => true + case _: InvalidUnsafeRowException => true + case e1 if e1.getCause != null => findStateSchemaException(e1.getCause) + case _ => false + } + } }
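Finally, how this surfaces to users. The config key and its default come from this patch; the session setup itself is illustrative:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("state-schema-check-demo")
  .getOrCreate()

// The check is on by default ("spark.sql.streaming.stateStore.stateSchemaCheck" = true),
// so restarting a query whose state schema changed incompatibly now fails fast with
// StateSchemaNotCompatible instead of a confusing NPE deep in state deserialization.

// Escape hatch: disable validation before restarting the query -- at the risk of the
// non-deterministic behavior the error message warns about.
spark.conf.set("spark.sql.streaming.stateStore.stateSchemaCheck", "false")
```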