diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index fb93033bb15d4..9eb206457809c 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -62,7 +62,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
     // Schema evolution is not supported yet. Here we only pick a single random sample file to
     // figure out the schema of the whole dataset.
     val sampleFile =
-      if (conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true)) {
+      if (AvroFileFormat.ignoreFilesWithoutExtensions(conf)) {
         files.find(_.getPath.getName.endsWith(".avro")).getOrElse {
           throw new FileNotFoundException(
             "No Avro files found. Hadoop option \"avro.mapred.ignore.inputs.without.extension\" " +
@@ -170,10 +170,7 @@ private[avro] class AvroFileFormat extends FileFormat with DataSourceRegister {
       // Doing input file filtering is improper because we may generate empty tasks that process no
       // input files but stress the scheduler. We should probably add a more general input file
       // filtering mechanism for `FileFormat` data sources. See SPARK-16317.
-      if (
-        conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, true) &&
-          !file.filePath.endsWith(".avro")
-      ) {
+      if (AvroFileFormat.ignoreFilesWithoutExtensions(conf) && !file.filePath.endsWith(".avro")) {
         Iterator.empty
       } else {
         val reader = {
@@ -278,4 +275,11 @@ private[avro] object AvroFileFormat {
       value.readFields(new DataInputStream(in))
     }
   }
+
+  def ignoreFilesWithoutExtensions(conf: Configuration): Boolean = {
+    // Files without .avro extensions are not ignored by default
+    val defaultValue = false
+
+    conf.getBoolean(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, defaultValue)
+  }
 }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 9c6526b29dca3..446b42124ceca 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.avro
 
 import java.io._
-import java.nio.file.Files
+import java.net.URL
+import java.nio.file.{Files, Path, Paths}
 import java.sql.{Date, Timestamp}
 import java.util.{TimeZone, UUID}
 
@@ -622,7 +623,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     intercept[FileNotFoundException] {
       withTempPath { dir =>
         FileUtils.touch(new File(dir, "test"))
-        spark.read.avro(dir.toString)
+        val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
+        try {
+          hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
+          spark.read.avro(dir.toString)
+        } finally {
+          hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
+        }
       }
     }
@@ -684,12 +690,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
 
       Files.createFile(new File(tempSaveDir, "non-avro").toPath)
 
-      val newDf = spark
-        .read
-        .option(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
-        .avro(tempSaveDir)
+      val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
+      val count = try {
+        hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
+        val newDf = spark
+          .read
+          .avro(tempSaveDir)
+        newDf.count()
+      } finally {
+        hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
+      }
 
-      assert(newDf.count == 8)
+      assert(count == 8)
     }
   }
 
@@ -805,4 +817,23 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       assert(readDf.collect().sameElements(writeDf.collect()))
     }
   }
+
+  test("SPARK-24805: do not ignore files without .avro extension by default") {
+    withTempDir { dir =>
+      Files.copy(
+        Paths.get(new URL(episodesAvro).toURI),
+        Paths.get(dir.getCanonicalPath, "episodes"))
+
+      val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes"
+      val df1 = spark.read.avro(fileWithoutExtension)
+      assert(df1.count == 8)
+
+      val schema = new StructType()
+        .add("title", StringType)
+        .add("air_date", StringType)
+        .add("doctor", IntegerType)
+      val df2 = spark.read.schema(schema).avro(fileWithoutExtension)
+      assert(df2.count == 8)
+    }
+  }
 }
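
A minimal usage sketch of the behavior after this patch (not part of the diff): with the new default of `false` in `ignoreFilesWithoutExtensions`, extension-less Avro files are read unless the caller opts back in via the Hadoop property quoted in the error message above (`avro.mapred.ignore.inputs.without.extension`). The path and the `import org.apache.spark.sql.avro._` implicit that provides `.avro(...)` are assumptions for illustration; the conf set/unset pattern mirrors the tests above.

    import org.apache.spark.sql.avro._

    // Assumed: an existing SparkSession named `spark` with the Avro data source on the classpath.
    val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
    try {
      // Opt back into the pre-SPARK-24805 behavior: skip files without the .avro extension.
      hadoopConf.set("avro.mapred.ignore.inputs.without.extension", "true")
      spark.read.avro("/path/to/avro/dir").show()
    } finally {
      // Restore the new default so later reads also pick up extension-less files.
      hadoopConf.unset("avro.mapred.ignore.inputs.without.extension")
    }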