diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 408c8f81f17ba..77bc0ba5548dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -169,13 +169,15 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( */ private def compact(batchId: Long, logs: Array[T]): Boolean = { val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) - val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs - if (super.add(batchId, compactLogs(allLogs).toArray)) { - true - } else { - // Return false as there is another writer. - false - } + val allLogs = validBatches.map { id => + super.get(id).getOrElse { + throw new IllegalStateException( + s"${batchIdToPath(id)} doesn't exist when compacting batch $batchId " + + s"(compactInterval: $compactInterval)") + } + }.flatten ++ logs + // Return false as there is another writer. + super.add(batchId, compactLogs(allLogs).toArray) } /** @@ -190,7 +192,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( if (latestId >= 0) { try { val logs = - getAllValidBatches(latestId, compactInterval).flatMap(id => super.get(id)).flatten + getAllValidBatches(latestId, compactInterval).map { id => + super.get(id).getOrElse { + throw new IllegalStateException( + s"${batchIdToPath(id)} doesn't exist " + + s"(latestId: $latestId, compactInterval: $compactInterval)") + } + }.flatten return compactLogs(logs).toArray } catch { case e: IOException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 33e6a1d5d6e18..8628471fdb925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -115,7 +115,10 @@ class FileStreamSourceLog( Map.empty[Long, Option[Array[FileEntry]]] } - (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + val batches = + (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + HDFSMetadataLog.verifyBatchIds(batches.map(_._1), startId, endId) + batches } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 46bfc297931fb..5f8973fd09460 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -123,7 +123,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: serialize(metadata, output) return Some(tempPath) } finally { - IOUtils.closeQuietly(output) + output.close() } } catch { case e: FileAlreadyExistsException => @@ -211,13 +211,17 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = { + assert(startId.isEmpty || endId.isEmpty || startId.get <= endId.get) val files = fileManager.list(metadataPath, batchFilesFilter) val batchIds = files .map(f => pathToBatchId(f.getPath)) .filter { batchId => (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) - } - batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { + }.sorted + + verifyBatchIds(batchIds, startId, endId) + + batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => (batchId, metadataOption.get) } @@ -437,4 +441,51 @@ object HDFSMetadataLog { } } } + + /** + * Verify if batchIds are continuous and between `startId` and `endId`. + * + * @param batchIds the sorted ids to verify. + * @param startId the start id. If it's set, batchIds should start with this id. + * @param endId the start id. If it's set, batchIds should end with this id. + */ + def verifyBatchIds(batchIds: Seq[Long], startId: Option[Long], endId: Option[Long]): Unit = { + // Verify that we can get all batches between `startId` and `endId`. + if (startId.isDefined || endId.isDefined) { + if (batchIds.isEmpty) { + throw new IllegalStateException(s"batch ${startId.orElse(endId).get} doesn't exist") + } + if (startId.isDefined) { + val minBatchId = batchIds.head + assert(minBatchId >= startId.get) + if (minBatchId != startId.get) { + val missingBatchIds = startId.get to minBatchId + throw new IllegalStateException( + s"batches (${missingBatchIds.mkString(", ")}) don't exist " + + s"(startId: $startId, endId: $endId)") + } + } + + if (endId.isDefined) { + val maxBatchId = batchIds.last + assert(maxBatchId <= endId.get) + if (maxBatchId != endId.get) { + val missingBatchIds = maxBatchId to endId.get + throw new IllegalStateException( + s"batches (${missingBatchIds.mkString(", ")}) don't exist " + + s"(startId: $startId, endId: $endId)") + } + } + } + + if (batchIds.nonEmpty) { + val minBatchId = batchIds.head + val maxBatchId = batchIds.last + val missingBatchIds = (minBatchId to maxBatchId).toSet -- batchIds + if (missingBatchIds.nonEmpty) { + throw new IllegalStateException(s"batches (${missingBatchIds.mkString(", ")}) " + + s"don't exist (startId: $startId, endId: $endId)") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 63c4dc17fddc5..16db353eef54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -429,7 +429,10 @@ class StreamExecution( availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ - offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId => + if (latestBatchId != 0) { + val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse { + throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist") + } committedOffsets = secondLatestBatchId.toStreamProgress(sources) } @@ -568,10 +571,14 @@ class StreamExecution( // Now that we've updated the scheduler's persistent checkpoint, it is safe for the // sources to discard data from the previous batch. - val prevBatchOff = offsetLog.get(currentBatchId - 1) - if (prevBatchOff.isDefined) { - prevBatchOff.get.toStreamProgress(sources).foreach { - case (src, off) => src.commit(off) + if (currentBatchId != 0) { + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } else { + throw new IllegalStateException(s"batch $currentBatchId doesn't exist") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 7689bc03a4ccf..48e70e48b1799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -259,6 +259,23 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { fm.rename(path2, path3) } } + + test("verifyBatchIds") { + import HDFSMetadataLog.verifyBatchIds + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) + verifyBatchIds(Seq(1L), Some(1L), Some(1L)) + verifyBatchIds(Seq(1L, 2L, 3L), None, Some(3L)) + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), None) + verifyBatchIds(Seq(1L, 2L, 3L), None, None) + + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(), None, Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), None, Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) + } } /** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 2108b118bf059..e2ec690d90e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1314,6 +1314,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + assert(metadataLog.add(1, Array(FileEntry(s"$scheme:///file2", 200L, 0)))) val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, dir.getAbsolutePath, Map.empty)