From f0fbc53114053f67527712f9c1bb629131a33665 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Thu, 21 Mar 2019 21:14:11 +0900 Subject: [PATCH 1/3] [SPARK-27237][SS] Introduce State schema validation among query restart --- .../apache/spark/sql/internal/SQLConf.scala | 9 ++ .../StateSchemaCompatibilityChecker.scala | 120 +++++++++++++++++ .../streaming/state/StateStore.scala | 46 ++++++- .../streaming/state/StateStoreConf.scala | 3 + .../state/StateStoreCoordinator.scala | 49 ++++++- ...StateSchemaCompatibilityCheckerSuite.scala | 121 ++++++++++++++++++ .../streaming/StreamingAggregationSuite.scala | 84 +++++++++++- 7 files changed, 424 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 07cd41b06de2..f99045ce6528 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1294,6 +1294,13 @@ object SQLConf { .createWithDefault( "org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider") + val STATE_SCHEMA_CHECK_ENABLED = + buildConf("spark.sql.streaming.stateStore.stateSchemaCheck") + .doc("When true, Spark will validate the state schema against schema on existing state and " + + "fail query if it's incompatible.") + .booleanConf + .createWithDefault(true) + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot") .internal() @@ -3064,6 +3071,8 @@ class SQLConf extends Serializable with Logging { def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) + def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) + def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala new file mode 100644 index 000000000000..73d56b1e4904 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StructField, StructType} + +case class StateSchemaNotCompatible(message: String) extends Exception(message) + +class StateSchemaCompatibilityChecker( + providerId: StateStoreProviderId, + hadoopConf: Configuration) extends Logging { + + private val storeCpLocation = providerId.storeId.storeCheckpointLocation() + private val fm = CheckpointFileManager.create(storeCpLocation, hadoopConf) + private val schemaFileLocation = schemaFile(storeCpLocation) + + fm.mkdirs(schemaFileLocation.getParent) + + def check(keySchema: StructType, valueSchema: StructType): Unit = { + if (fm.exists(schemaFileLocation)) { + logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.") + val (storedKeySchema, storedValueSchema) = readSchemaFile() + + def fieldCompatible(fieldOld: StructField, fieldNew: StructField): Boolean = { + // compatibility for nullable + // - same: OK + // - non-nullable -> nullable: OK + // - nullable -> non-nullable: Not compatible + (fieldOld.dataType == fieldNew.dataType) && + ((fieldOld.nullable == fieldNew.nullable) || + (!fieldOld.nullable && fieldNew.nullable)) + } + + def schemaCompatible(schemaOld: StructType, schemaNew: StructType): Boolean = { + (schemaOld.length == schemaNew.length) && + schemaOld.zip(schemaNew).forall { case (f1, f2) => fieldCompatible(f1, f2) } + } + + val errorMsg = "Provided schema doesn't match to the schema for existing state! " + + "Please note that Spark allow difference of field name: check count of fields " + + "and data type of each field.\n" + + s"- provided schema: key $keySchema value $valueSchema\n" + + s"- existing schema: key $storedKeySchema value $storedValueSchema\n" + + s"If you want to force running query without schema validation, please set " + + s"${SQLConf.STATE_SCHEMA_CHECK_ENABLED.key} to false." + + if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) { + // schema is exactly same + } else if (!schemaCompatible(storedKeySchema, keySchema) || + !schemaCompatible(storedValueSchema, valueSchema)) { + logError(errorMsg) + throw StateSchemaNotCompatible(errorMsg) + } else { + logInfo("Detected schema change which is compatible: will overwrite schema file to new.") + // It tries best-effort to overwrite current schema file. + // the schema validation doesn't break even it fails, though it might miss on detecting + // change which is not a big deal. + createSchemaFile(keySchema, valueSchema) + } + } else { + // schema doesn't exist, create one now + logDebug(s"Schema file for provider $providerId doesn't exist. 
Creating one.") + createSchemaFile(keySchema, valueSchema) + } + } + + private def readSchemaFile(): (StructType, StructType) = { + val inStream = fm.open(schemaFileLocation) + try { + val keySchemaStr = inStream.readUTF() + val valueSchemaStr = inStream.readUTF() + + (StructType.fromString(keySchemaStr), StructType.fromString(valueSchemaStr)) + } catch { + case e: Throwable => + logError(s"Fail to read schema file from $schemaFileLocation", e) + throw e + } finally { + inStream.close() + } + } + + private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { + val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = true) + try { + outStream.writeUTF(keySchema.json) + outStream.writeUTF(valueSchema.json) + outStream.close() + } catch { + case e: Throwable => + logError(s"Fail to write schema file to $schemaFileLocation", e) + outStream.cancel() + throw e + } + } + + private def schemaFile(storeCpLocation: Path): Path = + new Path(new Path(storeCpLocation, "_metadata"), "schema") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 05bcee7b05c6..e959c58f0323 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -280,14 +280,14 @@ object StateStoreProvider { * Return a instance of the required provider, initialized with the given configurations. */ def createAndInit( - stateStoreId: StateStoreId, + providerId: StateStoreProviderId, keySchema: StructType, valueSchema: StructType, indexOrdinal: Option[Int], // for sorting the data storeConf: StateStoreConf, hadoopConf: Configuration): StateStoreProvider = { val provider = create(storeConf.providerClass) - provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + provider.init(providerId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) provider } @@ -328,6 +328,11 @@ object StateStoreProviderId { stateInfo.checkpointLocation, stateInfo.operatorId, partitionIndex, storeName) StateStoreProviderId(storeId, stateInfo.queryRunId) } + + private[sql] def withNoPartitionInformation( + providerId: StateStoreProviderId): StateStoreProviderId = { + providerId.copy(storeId = providerId.storeId.copy(partitionId = -1)) + } } /** @@ -390,6 +395,9 @@ object StateStore extends Logging { @GuardedBy("loadedProviders") private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() + @GuardedBy("loadedProviders") + private val schemaValidated = new mutable.HashSet[StateStoreProviderId]() + /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` * will be called when an exception happens. 
@@ -467,10 +475,18 @@ object StateStore extends Logging { hadoopConf: Configuration): StateStoreProvider = { loadedProviders.synchronized { startMaintenanceIfNeeded() + + val newProvIdSchemaCheck = StateStoreProviderId.withNoPartitionInformation(storeProviderId) + if (!schemaValidated.contains(newProvIdSchemaCheck)) { + validateSchema(newProvIdSchemaCheck, keySchema, valueSchema, + storeConf.stateSchemaCheckEnabled) + schemaValidated.add(newProvIdSchemaCheck) + } + val provider = loadedProviders.getOrElseUpdate( storeProviderId, StateStoreProvider.createAndInit( - storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) + storeProviderId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf) ) reportActiveStoreInstance(storeProviderId) provider @@ -482,6 +498,12 @@ object StateStore extends Logging { loadedProviders.remove(storeProviderId).foreach(_.close()) } + /** Unload all state store providers: unit test purpose */ + private[sql] def unloadAll(): Unit = loadedProviders.synchronized { + loadedProviders.keySet.foreach { key => unload(key) } + loadedProviders.clear() + } + /** Whether a state store provider is loaded or not */ def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized { loadedProviders.contains(storeProviderId) @@ -564,6 +586,24 @@ object StateStore extends Logging { } } + private def validateSchema( + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + checkEnabled: Boolean): Unit = { + if (SparkEnv.get != null) { + val validated = coordinatorRef.flatMap( + _.validateSchema(storeProviderId, keySchema, valueSchema, checkEnabled)) + + validated match { + case Some(exc) => + // driver would log the information, so just re-throw here + throw exc + case None => + } + } + } + private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index 11043bc81ae3..23cb3be32c85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -55,6 +55,9 @@ class StateStoreConf( /** The compression codec used to compress delta and snapshot files. */ val compressionCodec: String = sqlConf.stateStoreCompressionCodec + /** whether to validate state schema during query run. */ + val stateSchemaCheckEnabled = sqlConf.isStateSchemaCheckEnabled + /** * Additional configurations related to state store. 
This will capture all configs in * SQLConf that start with `spark.sql.streaming.stateStore.` and extraOptions for a specific diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 2b14d37ee21e..773fae8e9a00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -20,11 +20,14 @@ package org.apache.spark.sql.execution.streaming.state import java.util.UUID import scala.collection.mutable +import scala.util.Try -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -43,6 +46,12 @@ private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executo private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage +private case class ValidateSchema( + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + checkEnabled: Boolean) extends StateStoreCoordinatorMessage + private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage @@ -59,7 +68,8 @@ object StateStoreCoordinatorRef extends Logging { */ def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { - val coordinator = new StateStoreCoordinator(env.rpcEnv) + + val coordinator = new StateStoreCoordinator(env.conf, env.rpcEnv) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -83,7 +93,6 @@ object StateStoreCoordinatorRef extends Logging { * [[StateStore]]s across all the executors, and get their locations for job scheduling. */ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { - private[sql] def reportActiveInstance( stateStoreProviderId: StateStoreProviderId, host: String, @@ -108,6 +117,16 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } + /** Validate state store operator's schema to see it's compatible with existing schema */ + private[sql] def validateSchema( + storeProviderId: StateStoreProviderId, + keySchema: StructType, + valueSchema: StructType, + checkEnabled: Boolean): Option[Exception] = { + rpcEndpointRef.askSync[Option[Exception]]( + ValidateSchema(storeProviderId, keySchema, valueSchema, checkEnabled)) + } + private[state] def stop(): Unit = { rpcEndpointRef.askSync[Boolean](StopCoordinator) } @@ -118,9 +137,12 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) +private class StateStoreCoordinator(conf: SparkConf, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] + private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]] + + private lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -150,6 +172,25 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) storeIdsToRemove.mkString(", ")) context.reply(true) + case ValidateSchema(providerId, keySchema, valueSchema, checkEnabled) => + // normalize partition ID to validate only once for one state operator + val newProviderId = StateStoreProviderId.withNoPartitionInformation(providerId) + + val result = schemaValidated.getOrElseUpdate(newProviderId, { + val checker = new StateSchemaCompatibilityChecker(newProviderId, hadoopConf) + + // regardless of configuration, we check compatibility to at least write schema file + // if necessary + val ret = Try(checker.check(keySchema, valueSchema)).toEither.fold(Some(_), _ => None) + if (checkEnabled) { + ret + } else { + None + } + }) + + context.reply(result) + case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered logInfo("StateStoreCoordinator stopped") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala new file mode 100644 index 000000000000..e7a555f4e98e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper.newDir +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { + + testQuietly("changing schema of state when restarting query") { + val opId = Random.nextInt(100000) + val partitionId = -1 + + val hadoopConf: Configuration = new Configuration() + + def runSchemaChecker( + dir: String, + queryId: UUID, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + // in fact, Spark doesn't support online state schema change, so need to check + // schema only once for each running of JVM + val providerId = StateStoreProviderId( + StateStoreId(dir, opId, partitionId), queryId) + + new StateSchemaCompatibilityChecker(providerId, hadoopConf) + .check(newKeySchema, newValueSchema) + } + + def verifyException( + oldKeySchema: StructType, + oldValueSchema: StructType, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + val dir = newDir() + val queryId = UUID.randomUUID() + runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) + + val e = intercept[StateSchemaNotCompatible] { + runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) + } + + e.getMessage.contains("Provided schema doesn't match to the schema for existing state!") + e.getMessage.contains(newKeySchema.json) + e.getMessage.contains(newValueSchema.json) + e.getMessage.contains(oldKeySchema.json) + e.getMessage.contains(oldValueSchema.json) + } + + def verifySuccess( + oldKeySchema: StructType, + oldValueSchema: StructType, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + val dir = newDir() + val queryId = UUID.randomUUID() + runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) + runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) + } + + val keySchema = new StructType() + .add(StructField("key1", IntegerType, nullable = true)) + .add(StructField("key2", StringType, nullable = true)) + + val valueSchema = new StructType() + .add(StructField("value1", IntegerType, nullable = true)) + .add(StructField("value2", StringType, nullable = true)) + + // adding field should fail + val fieldAddedKeySchema = keySchema.add(StructField("newKey", IntegerType)) + val fieldAddedValueSchema = valueSchema.add(StructField("newValue", IntegerType)) + verifyException(keySchema, valueSchema, fieldAddedKeySchema, fieldAddedValueSchema) + + // removing field should fail + val fieldRemovedKeySchema = StructType(keySchema.dropRight(1)) + val fieldRemovedValueSchema = StructType(valueSchema.drop(1)) + verifyException(keySchema, valueSchema, fieldRemovedKeySchema, fieldRemovedValueSchema) + + // changing the type of field should fail + val typeChangedKeySchema = StructType(keySchema.map(_.copy(dataType = TimestampType))) + val typeChangedValueSchema = StructType(keySchema.map(_.copy(dataType = TimestampType))) + verifyException(keySchema, valueSchema, typeChangedKeySchema, typeChangedValueSchema) + + // changing the nullability of nullable to non-nullable should fail + val nonNullChangedKeySchema = StructType(keySchema.map(_.copy(nullable = false))) + val nonNullChangedValueSchema = StructType(valueSchema.map(_.copy(nullable = false))) + verifyException(keySchema, valueSchema, nonNullChangedKeySchema, nonNullChangedValueSchema) + + // changing the 
nullability of non-nullable to nullable should be allowed + verifySuccess(nonNullChangedKeySchema, nonNullChangedValueSchema, keySchema, valueSchema) + + // changing the name of field should be allowed + val newName: StructField => StructField = f => f.copy(name = f.name + "_new") + val fieldNameChangedKeySchema = StructType(keySchema.map(newName)) + val fieldNameChangedValueSchema = StructType(valueSchema.map(newName)) + + verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, fieldNameChangedValueSchema) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 0524e2966201..b9bf8b025c51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.streaming import java.io.File import java.util.{Locale, TimeZone} +import scala.annotation.tailrec + import org.apache.commons.io.FileUtils import org.scalatest.Assertions @@ -33,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemorySink -import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager +import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ @@ -753,6 +755,86 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { ) } + testQuietlyWithAllStateVersions("changing schema of state when restarting query") { + withTempDir { tempDir => + val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) + + // if we don't have verification phase on state schema, modified query would throw NPE with + // stack trace which end users would not easily understand + + testStream(aggregated, Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 21), + ExpectFailure[SparkException] { e => + val stateSchemaExc = findStateSchemaNotCompatible(e) + assert(stateSchemaExc.isDefined) + val msg = stateSchemaExc.get.getMessage + assert(msg.contains("Provided schema doesn't match to the schema for existing state")) + // other verifications are presented in StateStoreSuite + } + ) + } + } + + testQuietlyWithAllStateVersions("changing schema of state when restarting query -" + + " schema check off", (SQLConf.STATE_SCHEMA_CHECK_ENABLED.key, "false")) { + withTempDir { tempDir => + val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) + + testStream(aggregated, Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 21), + ExpectFailure[SparkException] { e => + val stateSchemaExc = findStateSchemaNotCompatible(e) + // it would bring other error in runtime, but it shouldn't check schema in any way + assert(stateSchemaExc.isEmpty) + } + ) + } + } + + private def prepareTestForChangingSchemaOfState( + tempDir: File): (MemoryStream[Int], DataFrame) = { + val inputData = MemoryStream[Int] + val aggregated = inputData.toDF() + .selectExpr("value % 10 AS id", "value") + .groupBy($"id") + .agg( + 
sum("value").as("sum_value"), + avg("value").as("avg_value"), + max("value").as("max_value")) + + testStream(aggregated, Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 11), + CheckLastBatch((1L, 12L, 6.0, 11)), + StopStream + ) + + StateStore.unloadAll() + + val inputData2 = MemoryStream[Int] + val aggregated2 = inputData2.toDF() + .selectExpr("value % 10 AS id", "value") + .groupBy($"id") + .agg( + sum("value").as("sum_value"), + avg("value").as("avg_value"), + collect_list("value").as("values")) + + inputData2.addData(1, 11) + + (inputData2, aggregated2) + } + + @tailrec + private def findStateSchemaNotCompatible(exc: Throwable): Option[StateSchemaNotCompatible] = { + exc match { + case e1: StateSchemaNotCompatible => Some(e1) + case e1 if e1.getCause != null => findStateSchemaNotCompatible(e1.getCause) + case _ => None + } + } /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { From 08a3342cc068a635a1287a821a72bd77e9959b52 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 27 Nov 2020 13:03:25 +0900 Subject: [PATCH 2/3] Reflect review comments --- .../apache/spark/sql/internal/SQLConf.scala | 1 + .../StateSchemaCompatibilityChecker.scala | 58 +++-- .../state/StateStoreCoordinator.scala | 8 +- ...StateSchemaCompatibilityCheckerSuite.scala | 245 +++++++++++++----- 4 files changed, 222 insertions(+), 90 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f99045ce6528..27f3625255d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1298,6 +1298,7 @@ object SQLConf { buildConf("spark.sql.streaming.stateStore.stateSchemaCheck") .doc("When true, Spark will validate the state schema against schema on existing state and " + "fail query if it's incompatible.") + .version("3.1.0") .booleanConf .createWithDefault(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 73d56b1e4904..96fabf12df67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.streaming.CheckpointFileManager import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} case class StateSchemaNotCompatible(message: String) extends Exception(message) @@ -42,21 +42,6 @@ class StateSchemaCompatibilityChecker( logDebug(s"Schema file for provider $providerId exists. 
Comparing with provided schema.") val (storedKeySchema, storedValueSchema) = readSchemaFile() - def fieldCompatible(fieldOld: StructField, fieldNew: StructField): Boolean = { - // compatibility for nullable - // - same: OK - // - non-nullable -> nullable: OK - // - nullable -> non-nullable: Not compatible - (fieldOld.dataType == fieldNew.dataType) && - ((fieldOld.nullable == fieldNew.nullable) || - (!fieldOld.nullable && fieldNew.nullable)) - } - - def schemaCompatible(schemaOld: StructType, schemaNew: StructType): Boolean = { - (schemaOld.length == schemaNew.length) && - schemaOld.zip(schemaNew).forall { case (f1, f2) => fieldCompatible(f1, f2) } - } - val errorMsg = "Provided schema doesn't match to the schema for existing state! " + "Please note that Spark allow difference of field name: check count of fields " + "and data type of each field.\n" + @@ -67,8 +52,8 @@ class StateSchemaCompatibilityChecker( if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) { // schema is exactly same - } else if (!schemaCompatible(storedKeySchema, keySchema) || - !schemaCompatible(storedValueSchema, valueSchema)) { + } else if (!schemasCompatible(storedKeySchema, keySchema) || + !schemasCompatible(storedValueSchema, valueSchema)) { logError(errorMsg) throw StateSchemaNotCompatible(errorMsg) } else { @@ -85,9 +70,41 @@ class StateSchemaCompatibilityChecker( } } + private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = + equalsIgnoreCompatibleNullability(storedSchema, schema) + + private def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + // This implementations should be same with DataType.equalsIgnoreCompatibleNullability, except + // this shouldn't check the name equality. + (from, to) match { + case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + (tn || !fn) && + equalsIgnoreCompatibleNullability(fromKey, toKey) && + equalsIgnoreCompatibleNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } + private def readSchemaFile(): (StructType, StructType) = { val inStream = fm.open(schemaFileLocation) try { + val version = inStream.readInt() + // Currently we only support version 1, which we can simplify the version validation and + // the parse logic. 
+ require(version == StateSchemaCompatibilityChecker.VERSION, + s"version $version is not supported.") + val keySchemaStr = inStream.readUTF() val valueSchemaStr = inStream.readUTF() @@ -104,6 +121,7 @@ class StateSchemaCompatibilityChecker( private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = true) try { + outStream.writeInt(StateSchemaCompatibilityChecker.VERSION) outStream.writeUTF(keySchema.json) outStream.writeUTF(valueSchema.json) outStream.close() @@ -118,3 +136,7 @@ class StateSchemaCompatibilityChecker( private def schemaFile(storeCpLocation: Path): Path = new Path(new Path(storeCpLocation, "_metadata"), "schema") } + +object StateSchemaCompatibilityChecker { + val VERSION = 1 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 773fae8e9a00..85642d5098a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -173,11 +173,11 @@ private class StateStoreCoordinator(conf: SparkConf, override val rpcEnv: RpcEnv context.reply(true) case ValidateSchema(providerId, keySchema, valueSchema, checkEnabled) => - // normalize partition ID to validate only once for one state operator - val newProviderId = StateStoreProviderId.withNoPartitionInformation(providerId) + require(providerId.storeId.partitionId == -1, "Expect the normalized partition ID in" + + " provider ID") - val result = schemaValidated.getOrElseUpdate(newProviderId, { - val checker = new StateSchemaCompatibilityChecker(newProviderId, hadoopConf) + val result = schemaValidated.getOrElseUpdate(providerId, { + val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf) // regardless of configuration, we check compatibility to at least write schema file // if necessary diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index e7a555f4e98e..a155bdebd64d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -29,93 +29,202 @@ import org.apache.spark.sql.types._ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { - testQuietly("changing schema of state when restarting query") { - val opId = Random.nextInt(100000) - val partitionId = -1 - - val hadoopConf: Configuration = new Configuration() - - def runSchemaChecker( - dir: String, - queryId: UUID, - newKeySchema: StructType, - newValueSchema: StructType): Unit = { - // in fact, Spark doesn't support online state schema change, so need to check - // schema only once for each running of JVM - val providerId = StateStoreProviderId( - StateStoreId(dir, opId, partitionId), queryId) - - new StateSchemaCompatibilityChecker(providerId, hadoopConf) - .check(newKeySchema, newValueSchema) - } - - def verifyException( - oldKeySchema: StructType, - oldValueSchema: StructType, - newKeySchema: StructType, - newValueSchema: StructType): Unit = { - val dir = 
newDir() - val queryId = UUID.randomUUID() - runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) - - val e = intercept[StateSchemaNotCompatible] { - runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) - } - - e.getMessage.contains("Provided schema doesn't match to the schema for existing state!") - e.getMessage.contains(newKeySchema.json) - e.getMessage.contains(newValueSchema.json) - e.getMessage.contains(oldKeySchema.json) - e.getMessage.contains(oldValueSchema.json) - } + private val hadoopConf: Configuration = new Configuration() + private val opId = Random.nextInt(100000) + private val partitionId = -1 - def verifySuccess( - oldKeySchema: StructType, - oldValueSchema: StructType, - newKeySchema: StructType, - newValueSchema: StructType): Unit = { - val dir = newDir() - val queryId = UUID.randomUUID() - runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) - runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) - } + private val structSchema = new StructType() + .add(StructField("nested1", IntegerType, nullable = true)) + .add(StructField("nested2", StringType, nullable = true)) - val keySchema = new StructType() - .add(StructField("key1", IntegerType, nullable = true)) - .add(StructField("key2", StringType, nullable = true)) + private val keySchema = new StructType() + .add(StructField("key1", IntegerType, nullable = true)) + .add(StructField("key2", StringType, nullable = true)) + .add(StructField("key3", structSchema, nullable = true)) - val valueSchema = new StructType() - .add(StructField("value1", IntegerType, nullable = true)) - .add(StructField("value2", StringType, nullable = true)) + private val valueSchema = new StructType() + .add(StructField("value1", IntegerType, nullable = true)) + .add(StructField("value2", StringType, nullable = true)) + .add(StructField("value3", structSchema, nullable = true)) - // adding field should fail + test("adding field to key should fail") { val fieldAddedKeySchema = keySchema.add(StructField("newKey", IntegerType)) + verifyException(keySchema, valueSchema, fieldAddedKeySchema, valueSchema) + } + + test("adding field to value should fail") { val fieldAddedValueSchema = valueSchema.add(StructField("newValue", IntegerType)) - verifyException(keySchema, valueSchema, fieldAddedKeySchema, fieldAddedValueSchema) + verifyException(keySchema, valueSchema, keySchema, fieldAddedValueSchema) + } + + test("adding nested field in key should fail") { + val fieldAddedNestedSchema = structSchema.add(StructField("newNested", IntegerType)) + val newKeySchema = applyNewSchemaToNestedFieldInKey(fieldAddedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } - // removing field should fail + test("adding nested field in value should fail") { + val fieldAddedNestedSchema = structSchema.add(StructField("newNested", IntegerType)) + val newValueSchema = applyNewSchemaToNestedFieldInValue(fieldAddedNestedSchema) + verifyException(keySchema, valueSchema, keySchema, newValueSchema) + } + + test("removing field from key should fail") { val fieldRemovedKeySchema = StructType(keySchema.dropRight(1)) + verifyException(keySchema, valueSchema, fieldRemovedKeySchema, valueSchema) + } + + test("removing field from value should fail") { val fieldRemovedValueSchema = StructType(valueSchema.drop(1)) - verifyException(keySchema, valueSchema, fieldRemovedKeySchema, fieldRemovedValueSchema) + verifyException(keySchema, valueSchema, keySchema, fieldRemovedValueSchema) + } + + test("removing nested field from key should 
fail") { + val fieldRemovedNestedSchema = StructType(structSchema.dropRight(1)) + val newKeySchema = applyNewSchemaToNestedFieldInKey(fieldRemovedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } - // changing the type of field should fail + test("removing nested field from value should fail") { + val fieldRemovedNestedSchema = StructType(structSchema.drop(1)) + val newValueSchema = applyNewSchemaToNestedFieldInValue(fieldRemovedNestedSchema) + verifyException(keySchema, valueSchema, keySchema, newValueSchema) + } + + test("changing the type of field in key should fail") { val typeChangedKeySchema = StructType(keySchema.map(_.copy(dataType = TimestampType))) - val typeChangedValueSchema = StructType(keySchema.map(_.copy(dataType = TimestampType))) - verifyException(keySchema, valueSchema, typeChangedKeySchema, typeChangedValueSchema) + verifyException(keySchema, valueSchema, typeChangedKeySchema, valueSchema) + } - // changing the nullability of nullable to non-nullable should fail + test("changing the type of field in value should fail") { + val typeChangedValueSchema = StructType(valueSchema.map(_.copy(dataType = TimestampType))) + verifyException(keySchema, valueSchema, keySchema, typeChangedValueSchema) + } + + test("changing the type of nested field in key should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(dataType = TimestampType))) + val newKeySchema = applyNewSchemaToNestedFieldInKey(typeChangedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } + + test("changing the type of nested field in value should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(dataType = TimestampType))) + val newValueSchema = applyNewSchemaToNestedFieldInValue(typeChangedNestedSchema) + verifyException(keySchema, valueSchema, keySchema, newValueSchema) + } + + test("changing the nullability of nullable to non-nullable in key should fail") { val nonNullChangedKeySchema = StructType(keySchema.map(_.copy(nullable = false))) + verifyException(keySchema, valueSchema, nonNullChangedKeySchema, valueSchema) + } + + test("changing the nullability of nullable to non-nullable in value should fail") { val nonNullChangedValueSchema = StructType(valueSchema.map(_.copy(nullable = false))) - verifyException(keySchema, valueSchema, nonNullChangedKeySchema, nonNullChangedValueSchema) + verifyException(keySchema, valueSchema, keySchema, nonNullChangedValueSchema) + } + + test("changing the nullability of nullable to nonnullable in nested field in key should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) + val newKeySchema = applyNewSchemaToNestedFieldInKey(typeChangedNestedSchema) + verifyException(keySchema, valueSchema, newKeySchema, valueSchema) + } - // changing the nullability of non-nullable to nullable should be allowed - verifySuccess(nonNullChangedKeySchema, nonNullChangedValueSchema, keySchema, valueSchema) + test("changing the nullability of nullable to nonnullable in nested field in value should fail") { + val typeChangedNestedSchema = StructType(structSchema.map(_.copy(nullable = false))) + val newValueSchema = applyNewSchemaToNestedFieldInValue(typeChangedNestedSchema) + verifyException(keySchema, valueSchema, keySchema, newValueSchema) + } - // changing the name of field should be allowed + test("changing the name of field in key should be allowed") { val newName: StructField => StructField = f => f.copy(name = f.name + "_new") val 
fieldNameChangedKeySchema = StructType(keySchema.map(newName)) + verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, valueSchema) + } + + test("changing the name of field in value should be allowed") { + val newName: StructField => StructField = f => f.copy(name = f.name + "_new") val fieldNameChangedValueSchema = StructType(valueSchema.map(newName)) + verifySuccess(keySchema, valueSchema, keySchema, fieldNameChangedValueSchema) + } + + test("changing the name of nested field in key should be allowed") { + val newName: StructField => StructField = f => f.copy(name = f.name + "_new") + val newNestedFieldsSchema = StructType(structSchema.map(newName)) + val fieldNameChangedKeySchema = applyNewSchemaToNestedFieldInKey(newNestedFieldsSchema) + verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, valueSchema) + } + + test("changing the name of nested field in value should be allowed") { + val newName: StructField => StructField = f => f.copy(name = f.name + "_new") + val newNestedFieldsSchema = StructType(structSchema.map(newName)) + val fieldNameChangedValueSchema = applyNewSchemaToNestedFieldInValue(newNestedFieldsSchema) + verifySuccess(keySchema, valueSchema, keySchema, fieldNameChangedValueSchema) + } + + private def applyNewSchemaToNestedFieldInKey(newNestedSchema: StructType): StructType = { + applyNewSchemaToNestedField(keySchema, newNestedSchema, "key3") + } + + private def applyNewSchemaToNestedFieldInValue(newNestedSchema: StructType): StructType = { + applyNewSchemaToNestedField(valueSchema, newNestedSchema, "value3") + } + + private def applyNewSchemaToNestedField( + originSchema: StructType, + newNestedSchema: StructType, + fieldName: String): StructType = { + val newFields = originSchema.map { field => + if (field.name == fieldName) { + field.copy(dataType = newNestedSchema) + } else { + field + } + } + StructType(newFields) + } + + private def runSchemaChecker( + dir: String, + queryId: UUID, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + // in fact, Spark doesn't support online state schema change, so need to check + // schema only once for each running of JVM + val providerId = StateStoreProviderId( + StateStoreId(dir, opId, partitionId), queryId) + + new StateSchemaCompatibilityChecker(providerId, hadoopConf) + .check(newKeySchema, newValueSchema) + } + + private def verifyException( + oldKeySchema: StructType, + oldValueSchema: StructType, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + val dir = newDir() + val queryId = UUID.randomUUID() + runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) + + val e = intercept[StateSchemaNotCompatible] { + runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) + } + + e.getMessage.contains("Provided schema doesn't match to the schema for existing state!") + e.getMessage.contains(newKeySchema.json) + e.getMessage.contains(newValueSchema.json) + e.getMessage.contains(oldKeySchema.json) + e.getMessage.contains(oldValueSchema.json) + } - verifySuccess(keySchema, valueSchema, fieldNameChangedKeySchema, fieldNameChangedValueSchema) + private def verifySuccess( + oldKeySchema: StructType, + oldValueSchema: StructType, + newKeySchema: StructType, + newValueSchema: StructType): Unit = { + val dir = newDir() + val queryId = UUID.randomUUID() + runSchemaChecker(dir, queryId, oldKeySchema, oldValueSchema) + runSchemaChecker(dir, queryId, newKeySchema, newValueSchema) } } From 2901429ab24ca713bf74d051a84eec2f6a2433ca Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim 
(HeartSaVioR)" Date: Thu, 3 Dec 2020 20:34:17 +0900 Subject: [PATCH 3/3] Address review comments --- .../org/apache/spark/sql/types/DataType.scala | 38 +++++++++-- .../execution/streaming/HDFSMetadataLog.scala | 32 +-------- .../streaming/MetadataVersionUtil.scala | 51 ++++++++++++++ .../StateSchemaCompatibilityChecker.scala | 66 ++++++------------- .../streaming/state/StateStore.scala | 48 ++++++-------- .../state/StateStoreCoordinator.scala | 49 ++------------ ...StateSchemaCompatibilityCheckerSuite.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 7 +- ...ngStateStoreFormatCompatibilitySuite.scala | 21 ++++-- 9 files changed, 152 insertions(+), 162 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e4ee6eb377a4..9e820f0796a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -307,21 +307,49 @@ object DataType { * of `fromField.nullable` and `toField.nullable` are false. */ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + equalsIgnoreCompatibleNullability(from, to, ignoreName = false) + } + + /** + * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType, and + * also the field name. It compares based on the position. + * + * Compatible nullability is defined as follows: + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` + * if and only if `to.containsNull` is true, or both of `from.containsNull` and + * `to.containsNull` are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` + * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` + * if and only if for all every pair of fields, `to.nullable` is true, or both + * of `fromField.nullable` and `toField.nullable` are false. 
+ */ + private[sql] def equalsIgnoreNameAndCompatibleNullability( + from: DataType, + to: DataType): Boolean = { + equalsIgnoreCompatibleNullability(from, to, ignoreName = true) + } + + private def equalsIgnoreCompatibleNullability( + from: DataType, + to: DataType, + ignoreName: Boolean = false): Boolean = { (from, to) match { case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement, ignoreName) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) + equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) && + equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName) case (StructType(fromFields), StructType(toFields)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && + (ignoreName || fromField.name == toField.name) && (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName) } case (fromDataType, toDataType) => fromDataType == toDataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 893639a86c88..b87a5b49eb6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -267,36 +267,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - /** - * Parse the log version from the given `text` -- will throw exception when the parsed version - * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", - * "v123xyz" etc.) - */ - private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int = { - if (text.length > 0 && text(0) == 'v') { - val version = - try { - text.substring(1, text.length).toInt - } catch { - case _: NumberFormatException => - throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + - s"version from $text.") - } - if (version > 0) { - if (version > maxSupportedVersion) { - throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + - s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + - s"by a newer version of Spark and cannot be read by this version. 
Please upgrade.") - } else { - return version - } - } - } - - // reaching here means we failed to read the correct log version - throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + - s"version from $text.") - } + private[sql] def validateVersion(text: String, maxSupportedVersion: Int): Int = + MetadataVersionUtil.validateVersion(text, maxSupportedVersion) } object HDFSMetadataLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala new file mode 100644 index 000000000000..548f2aa5d5c5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataVersionUtil.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +object MetadataVersionUtil { + /** + * Parse the log version from the given `text` -- will throw exception when the parsed version + * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", + * "v123xyz" etc.) + */ + def validateVersion(text: String, maxSupportedVersion: Int): Int = { + if (text.length > 0 && text(0) == 'v') { + val version = + try { + text.substring(1, text.length).toInt + } catch { + case _: NumberFormatException => + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } + if (version > 0) { + if (version > maxSupportedVersion) { + throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " + + s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " + + s"by a newer version of Spark and cannot be read by this version. 
Please upgrade.") + } else { + return version + } + } + } + + // reaching here means we failed to read the correct log version + throw new IllegalStateException(s"Log file was malformed: failed to read correct log " + + s"version from $text.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala index 96fabf12df67..4ac12c089c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala @@ -21,9 +21,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} case class StateSchemaNotCompatible(message: String) extends Exception(message) @@ -41,27 +41,25 @@ class StateSchemaCompatibilityChecker( if (fm.exists(schemaFileLocation)) { logDebug(s"Schema file for provider $providerId exists. Comparing with provided schema.") val (storedKeySchema, storedValueSchema) = readSchemaFile() - - val errorMsg = "Provided schema doesn't match to the schema for existing state! " + - "Please note that Spark allow difference of field name: check count of fields " + - "and data type of each field.\n" + - s"- provided schema: key $keySchema value $valueSchema\n" + - s"- existing schema: key $storedKeySchema value $storedValueSchema\n" + - s"If you want to force running query without schema validation, please set " + - s"${SQLConf.STATE_SCHEMA_CHECK_ENABLED.key} to false." - if (storedKeySchema.equals(keySchema) && storedValueSchema.equals(valueSchema)) { // schema is exactly same } else if (!schemasCompatible(storedKeySchema, keySchema) || !schemasCompatible(storedValueSchema, valueSchema)) { + val errorMsg = "Provided schema doesn't match to the schema for existing state! " + + "Please note that Spark allow difference of field name: check count of fields " + + "and data type of each field.\n" + + s"- Provided key schema: $keySchema\n" + + s"- Provided value schema: $valueSchema\n" + + s"- Existing key schema: $storedKeySchema\n" + + s"- Existing value schema: $storedValueSchema\n" + + s"If you want to force running query without schema validation, please set " + + s"${SQLConf.STATE_SCHEMA_CHECK_ENABLED.key} to false.\n" + + "Please note running query with incompatible schema could cause indeterministic" + + " behavior." logError(errorMsg) throw StateSchemaNotCompatible(errorMsg) } else { - logInfo("Detected schema change which is compatible: will overwrite schema file to new.") - // It tries best-effort to overwrite current schema file. - // the schema validation doesn't break even it fails, though it might miss on detecting - // change which is not a big deal. - createSchemaFile(keySchema, valueSchema) + logInfo("Detected schema change which is compatible. 
Allowing to put rows.") } } else { // schema doesn't exist, create one now @@ -71,39 +69,17 @@ class StateSchemaCompatibilityChecker( } private def schemasCompatible(storedSchema: StructType, schema: StructType): Boolean = - equalsIgnoreCompatibleNullability(storedSchema, schema) - - private def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { - // This implementations should be same with DataType.equalsIgnoreCompatibleNullability, except - // this shouldn't check the name equality. - (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } - } + DataType.equalsIgnoreNameAndCompatibleNullability(storedSchema, schema) private def readSchemaFile(): (StructType, StructType) = { val inStream = fm.open(schemaFileLocation) try { - val version = inStream.readInt() + val versionStr = inStream.readUTF() // Currently we only support version 1, which we can simplify the version validation and // the parse logic. - require(version == StateSchemaCompatibilityChecker.VERSION, - s"version $version is not supported.") + val version = MetadataVersionUtil.validateVersion(versionStr, + StateSchemaCompatibilityChecker.VERSION) + require(version == 1) val keySchemaStr = inStream.readUTF() val valueSchemaStr = inStream.readUTF() @@ -119,9 +95,9 @@ class StateSchemaCompatibilityChecker( } private def createSchemaFile(keySchema: StructType, valueSchema: StructType): Unit = { - val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = true) + val outStream = fm.createAtomic(schemaFileLocation, overwriteIfPossible = false) try { - outStream.writeInt(StateSchemaCompatibilityChecker.VERSION) + outStream.writeUTF(s"v${StateSchemaCompatibilityChecker.VERSION}") outStream.writeUTF(keySchema.json) outStream.writeUTF(valueSchema.json) outStream.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index e959c58f0323..ab67c19783ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ScheduledFuture, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable +import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -328,11 +329,6 @@ object StateStoreProviderId { stateInfo.checkpointLocation, stateInfo.operatorId, partitionIndex, storeName) StateStoreProviderId(storeId, stateInfo.queryRunId) } - - private[sql] def withNoPartitionInformation( - providerId: StateStoreProviderId): StateStoreProviderId = { - providerId.copy(storeId = providerId.storeId.copy(partitionId = -1)) - } } /** @@ -391,12 +387,13 @@ object StateStore 
extends Logging { val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 + val PARTITION_ID_TO_CHECK_SCHEMA = 0 @GuardedBy("loadedProviders") private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]() @GuardedBy("loadedProviders") - private val schemaValidated = new mutable.HashSet[StateStoreProviderId]() + private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]]() /** * Runs the `task` periodically and automatically cancels it if there is an exception. `onError` @@ -476,11 +473,22 @@ object StateStore extends Logging { loadedProviders.synchronized { startMaintenanceIfNeeded() - val newProvIdSchemaCheck = StateStoreProviderId.withNoPartitionInformation(storeProviderId) - if (!schemaValidated.contains(newProvIdSchemaCheck)) { - validateSchema(newProvIdSchemaCheck, keySchema, valueSchema, - storeConf.stateSchemaCheckEnabled) - schemaValidated.add(newProvIdSchemaCheck) + if (storeProviderId.storeId.partitionId == PARTITION_ID_TO_CHECK_SCHEMA) { + val result = schemaValidated.getOrElseUpdate(storeProviderId, { + val checker = new StateSchemaCompatibilityChecker(storeProviderId, hadoopConf) + // regardless of configuration, we check compatibility to at least write schema file + // if necessary + val ret = Try(checker.check(keySchema, valueSchema)).toEither.fold(Some(_), _ => None) + if (storeConf.stateSchemaCheckEnabled) { + ret + } else { + None + } + }) + + if (result.isDefined) { + throw result.get + } } val provider = loadedProviders.getOrElseUpdate( @@ -586,24 +594,6 @@ object StateStore extends Logging { } } - private def validateSchema( - storeProviderId: StateStoreProviderId, - keySchema: StructType, - valueSchema: StructType, - checkEnabled: Boolean): Unit = { - if (SparkEnv.get != null) { - val validated = coordinatorRef.flatMap( - _.validateSchema(storeProviderId, keySchema, valueSchema, checkEnabled)) - - validated match { - case Some(exc) => - // driver would log the information, so just re-throw here - throw exc - case None => - } - } - } - private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 85642d5098a1..2b14d37ee21e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.execution.streaming.state import java.util.UUID import scala.collection.mutable -import scala.util.Try -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.types.StructType import org.apache.spark.util.RpcUtils /** Trait representing all messages to [[StateStoreCoordinator]] */ @@ -46,12 +43,6 @@ private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executo private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage 
-private case class ValidateSchema( - storeProviderId: StateStoreProviderId, - keySchema: StructType, - valueSchema: StructType, - checkEnabled: Boolean) extends StateStoreCoordinatorMessage - private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage @@ -68,8 +59,7 @@ object StateStoreCoordinatorRef extends Logging { */ def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { - - val coordinator = new StateStoreCoordinator(env.conf, env.rpcEnv) + val coordinator = new StateStoreCoordinator(env.rpcEnv) val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator) logInfo("Registered StateStoreCoordinator endpoint") new StateStoreCoordinatorRef(coordinatorRef) @@ -93,6 +83,7 @@ object StateStoreCoordinatorRef extends Logging { * [[StateStore]]s across all the executors, and get their locations for job scheduling. */ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { + private[sql] def reportActiveInstance( stateStoreProviderId: StateStoreProviderId, host: String, @@ -117,16 +108,6 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId)) } - /** Validate state store operator's schema to see it's compatible with existing schema */ - private[sql] def validateSchema( - storeProviderId: StateStoreProviderId, - keySchema: StructType, - valueSchema: StructType, - checkEnabled: Boolean): Option[Exception] = { - rpcEndpointRef.askSync[Option[Exception]]( - ValidateSchema(storeProviderId, keySchema, valueSchema, checkEnabled)) - } - private[state] def stop(): Unit = { rpcEndpointRef.askSync[Boolean](StopCoordinator) } @@ -137,12 +118,9 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(conf: SparkConf, override val rpcEnv: RpcEnv) +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation] - private val schemaValidated = new mutable.HashMap[StateStoreProviderId, Option[Throwable]] - - private lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => @@ -172,25 +150,6 @@ private class StateStoreCoordinator(conf: SparkConf, override val rpcEnv: RpcEnv storeIdsToRemove.mkString(", ")) context.reply(true) - case ValidateSchema(providerId, keySchema, valueSchema, checkEnabled) => - require(providerId.storeId.partitionId == -1, "Expect the normalized partition ID in" + - " provider ID") - - val result = schemaValidated.getOrElseUpdate(providerId, { - val checker = new StateSchemaCompatibilityChecker(providerId, hadoopConf) - - // regardless of configuration, we check compatibility to at least write schema file - // if necessary - val ret = Try(checker.check(keySchema, valueSchema)).toEither.fold(Some(_), _ => None) - if (checkEnabled) { - ret - } else { - None - } - }) - - context.reply(result) - case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered logInfo("StateStoreCoordinator stopped") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index a155bdebd64d..4eb7603b316a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -31,7 +31,7 @@ class StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { private val hadoopConf: Configuration = new Configuration() private val opId = Random.nextInt(100000) - private val partitionId = -1 + private val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA private val structSchema = new StructType() .add(StructField("nested1", IntegerType, nullable = true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index b9bf8b025c51..491b0d8b2c26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -755,7 +755,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { ) } - testQuietlyWithAllStateVersions("changing schema of state when restarting query") { + testQuietlyWithAllStateVersions("changing schema of state when restarting query", + (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) { withTempDir { tempDir => val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) @@ -777,7 +778,9 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions { } testQuietlyWithAllStateVersions("changing schema of state when restarting query -" + - " schema check off", (SQLConf.STATE_SCHEMA_CHECK_ENABLED.key, "false")) { + " schema check off", + 
(SQLConf.STATE_SCHEMA_CHECK_ENABLED.key, "false"), + (SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key, "false")) { withTempDir { tempDir => val (inputData, aggregated) = prepareTestForChangingSchemaOfState(tempDir) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala index 33f6b02acb6d..1032d6c5b6ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStateStoreFormatCompatibilitySuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import java.io.File +import scala.annotation.tailrec + import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{InvalidUnsafeRowException, StateSchemaNotCompatible} import org.apache.spark.sql.functions._ import org.apache.spark.util.Utils @@ -239,11 +242,19 @@ class StreamingStateStoreFormatCompatibilitySuite extends StreamTest { CheckAnswer(Row(0, 20, Seq(0, 2, 4, 6, 8)), Row(1, 25, Seq(1, 3, 5, 7, 9))) */ AddData(inputData, 10 to 19: _*), - ExpectFailure[SparkException](e => { - // Check the exception message to make sure the state store format changing. - assert(e.getCause.getCause.getMessage.contains( - "The streaming query failed by state format invalidation.")) - }) + ExpectFailure[SparkException] { e => + assert(findStateSchemaException(e)) + } ) } + + @tailrec + private def findStateSchemaException(exc: Throwable): Boolean = { + exc match { + case _: StateSchemaNotCompatible => true + case _: InvalidUnsafeRowException => true + case e1 if e1.getCause != null => findStateSchemaException(e1.getCause) + case _ => false + } + } }
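The snippet below is an illustrative sketch, not part of the diff: it shows which state schema
changes the checker accepts on query restart and how a user can opt out of the validation. It
assumes an active SparkSession named spark; the schemas are invented for the example and only
mirror the rule applied by schemasCompatible (field names may differ and nullability may widen
from non-nullable to nullable, but data types must stay the same).

    import org.apache.spark.sql.types._

    // Value schema written by the first run of the query (kept in the checkpoint).
    val storedValueSchema = new StructType()
      .add("cnt", LongType, nullable = false)

    // Restart 1: field renamed and nullability widened. This is treated as compatible,
    // so the query keeps running and the existing schema file is left untouched.
    val renamedValueSchema = new StructType()
      .add("count", LongType, nullable = true)

    // Restart 2: data type changed (LongType -> IntegerType). StateSchemaNotCompatible
    // is thrown from StateStore.get() for partition 0 when the schema check is enabled.
    val retypedValueSchema = new StructType()
      .add("cnt", IntegerType, nullable = false)

    // Escape hatch, at the user's own risk: skip the validation on restart.
    // The key is the one defined by SQLConf.STATE_SCHEMA_CHECK_ENABLED.
    spark.conf.set("spark.sql.streaming.stateStore.stateSchemaCheck", "false")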