From ad7ef8137be04352730d2df7339cb5f00c3190c3 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 21 Oct 2016 16:29:18 +0800
Subject: [PATCH] FileStreamSource should not infer partitions in every batch

---
 .../execution/datasources/DataSource.scala      | 26 +++++++++++++------
 .../streaming/FileStreamSource.scala            |  2 ++
 .../streaming/FileStreamSourceSuite.scala       |  2 +-
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 92b1fff7d8127..17da606580eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -75,7 +75,7 @@ case class DataSource(
     bucketSpec: Option[BucketSpec] = None,
     options: Map[String, String] = Map.empty) extends Logging {
 
-  case class SourceInfo(name: String, schema: StructType)
+  case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])
 
   lazy val providingClass: Class[_] = lookupDataSource(className)
   lazy val sourceInfo = sourceSchema()
@@ -186,8 +186,11 @@ case class DataSource(
     }
   }
 
-  private def inferFileFormatSchema(format: FileFormat): StructType = {
-    userSpecifiedSchema.orElse {
+  /**
+   * Infer the schema of the given FileFormat, returns a pair of schema and partition column names.
+   */
+  private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = {
+    userSpecifiedSchema.map(_ -> partitionColumns).orElse {
       val caseInsensitiveOptions = new CaseInsensitiveMap(options)
       val allPaths = caseInsensitiveOptions.get("path")
       val globbedPaths = allPaths.toSeq.flatMap { path =>
@@ -197,14 +200,14 @@ case class DataSource(
         SparkHadoopUtil.get.globPathIfNecessary(qualified)
       }.toArray
       val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None)
-      val partitionCols = fileCatalog.partitionSpec().partitionColumns.fields
+      val partitionSchema = fileCatalog.partitionSpec().partitionColumns
       val inferred = format.inferSchema(
         sparkSession,
         caseInsensitiveOptions,
         fileCatalog.allFiles())
 
       inferred.map { inferredSchema =>
-        StructType(inferredSchema ++ partitionCols)
+        StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name)
       }
     }.getOrElse {
       throw new AnalysisException("Unable to infer schema. It must be specified manually.")
@@ -217,7 +220,7 @@ case class DataSource(
       case s: StreamSourceProvider =>
         val (name, schema) = s.sourceSchema(
           sparkSession.sqlContext, userSpecifiedSchema, className, options)
-        SourceInfo(name, schema)
+        SourceInfo(name, schema, Nil)
 
       case format: FileFormat =>
         val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -246,7 +249,8 @@ case class DataSource(
             "you may be able to create a static DataFrame on that directory with " +
             "'spark.read.load(directory)' and infer schema from it.")
         }
-        SourceInfo(s"FileSource[$path]", inferFileFormatSchema(format))
+        val (schema, partCols) = inferFileFormatSchema(format)
+        SourceInfo(s"FileSource[$path]", schema, partCols)
 
       case _ =>
         throw new UnsupportedOperationException(
@@ -266,7 +270,13 @@ case class DataSource(
           throw new IllegalArgumentException("'path' is not specified")
         })
         new FileStreamSource(
-          sparkSession, path, className, sourceInfo.schema, metadataPath, options)
+          sparkSession = sparkSession,
+          path = path,
+          fileFormatClassName = className,
+          schema = sourceInfo.schema,
+          partitionColumns = sourceInfo.partitionColumns,
+          metadataPath = metadataPath,
+          options = options)
       case _ =>
         throw new UnsupportedOperationException(
           s"Data source $className does not support streamed reading")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 614a6261e7c28..115edf7ab2b61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -35,6 +35,7 @@ class FileStreamSource(
     path: String,
     fileFormatClassName: String,
     override val schema: StructType,
+    partitionColumns: Seq[String],
     metadataPath: String,
     options: Map[String, String]) extends Source with Logging {
 
@@ -142,6 +143,7 @@ class FileStreamSource(
         sparkSession,
         paths = files.map(_.path),
         userSpecifiedSchema = Some(schema),
+        partitionColumns = partitionColumns,
         className = fileFormatClassName,
         options = optionsWithPartitionBasePath)
     Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala
index 3e1e1126f9e6b..4a47c04d3f084 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala
@@ -94,7 +94,7 @@ class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext {
         new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath)
       assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0))))
 
-      val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil),
+      val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil,
        dir.getAbsolutePath, Map.empty)
       // this method should throw an exception if `fs.exists` is called during resolveRelation
       newSource.getBatch(None, LongOffset(1))
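
Note (not part of the patch): the sketch below is an illustrative Scala example of the user-facing path this change touches, a streaming read over a partition-structured directory. The application name, the directory layout (/data/events/date=.../), and the reliance on spark.sql.streaming.schemaInference are assumptions for illustration only. With the patch applied, the partition columns discovered when the source is created are carried in SourceInfo and reused by every micro-batch instead of being re-inferred each time.

import org.apache.spark.sql.SparkSession

object PartitionedFileStreamExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("partitioned-file-stream")
      .master("local[2]")
      .getOrCreate()

    // Let the file stream source infer the schema (and partition columns) once,
    // when the source is created, instead of requiring an explicit .schema(...).
    spark.conf.set("spark.sql.streaming.schemaInference", "true")

    // Assumed layout: /data/events/date=2016-10-21/part-*.parquet
    val events = spark.readStream
      .format("parquet")
      .load("/data/events")

    // With this patch, each micro-batch reuses sourceInfo.partitionColumns
    // rather than re-running partition discovery over the input directory.
    val query = events.writeStream
      .format("console")
      .outputMode("append")
      .start()

    query.awaitTermination()
  }
}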