diff --git a/README.md b/README.md
index 56c92e1..84c3aa1 100755
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ When reading files the API accepts several options:
 * `charset`: defaults to 'UTF-8' but can be set to other valid charset names
 * `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
 * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
-* `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
+* `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec` or one of the case-insensitive shortened names (`bzip2`, `gzip`, `lz4`, and `snappy`). Defaults to no compression when a codec is not specified.
 * `nullValue`: specify a string that indicates a null value; any fields matching this string will be set as nulls in the DataFrame
 
 The package also supports saving simple (non-nested) DataFrames. When saving you can specify the delimiter and whether we should generate a header row for the table. See the following examples for more details.
diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
index 5a09176..ccf539b 100755
--- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
+++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -326,7 +326,7 @@ case class CsvRelation protected[spark] (
       }
       // Write the data. We assume that schema isn't changed, and we won't update it.
-      val codecClass = compresionCodecClass(codec)
+      val codecClass = CompressionCodecs.getCodecClass(codec)
       data.saveAsCsvFile(filesystemPath.toString, Map("delimiter" -> delimiter.toString),
         codecClass)
     } else {
diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
index c2e1481..10e377d 100755
--- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -19,7 +19,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
-import com.databricks.spark.csv.util.{ParserLibs, TextFile, TypeCast}
+import com.databricks.spark.csv.util.{CompressionCodecs, ParserLibs, TextFile, TypeCast}
 
 /**
  * Provides access to CSV data from pure SQL statements (i.e. for users of the
@@ -183,7 +183,7 @@
     }
     if (doSave) {
      // Only save data when the save mode is not ignore.
-      val codecClass = compresionCodecClass(parameters.getOrElse("codec", null))
+      val codecClass = CompressionCodecs.getCodecClass(parameters.getOrElse("codec", null))
       data.saveAsCsvFile(path, parameters, codecClass)
     }
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 71dab64..6251b0d 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -26,16 +26,6 @@ package object csv {
   val defaultCsvFormat =
     CSVFormat.DEFAULT.withRecordSeparator(System.getProperty("line.separator", "\n"))
 
-  private[csv] def compresionCodecClass(className: String): Class[_ <: CompressionCodec] = {
-    className match {
-      case null => null
-      case codec =>
-        // scalastyle:off classforname
-        Class.forName(codec).asInstanceOf[Class[CompressionCodec]]
-        // scalastyle:on classforname
-    }
-  }
-
   /**
    * Adds a method, `csvFile`, to SQLContext that allows reading CSV data.
    */
diff --git a/src/main/scala/com/databricks/spark/csv/util/CompressionCodecs.scala b/src/main/scala/com/databricks/spark/csv/util/CompressionCodecs.scala
new file mode 100644
index 0000000..26dc681
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/csv/util/CompressionCodecs.scala
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.csv.util
+
+import scala.util.control.Exception._
+
+import org.apache.hadoop.io.compress._
+
+private[csv] object CompressionCodecs {
+  private val shortCompressionCodecNames: Map[String, String] = {
+    val codecMap = collection.mutable.Map.empty[String, String]
+    allCatch toTry(codecMap += "bzip2" -> classOf[BZip2Codec].getName)
+    allCatch toTry(codecMap += "gzip" -> classOf[GzipCodec].getName)
+    allCatch toTry(codecMap += "lz4" -> classOf[Lz4Codec].getName)
+    allCatch toTry(codecMap += "snappy" -> classOf[SnappyCodec].getName)
+    codecMap.toMap
+  }
+
+  /**
+   * Return the codec class of the given name.
+   */
+  def getCodecClass: String => Class[_ <: CompressionCodec] = {
+    case null => null
+    case codec =>
+      val codecName = shortCompressionCodecNames.getOrElse(codec.toLowerCase, codec)
+      try {
+        // scalastyle:off classforname
+        Class.forName(codecName).asInstanceOf[Class[CompressionCodec]]
+        // scalastyle:on classforname
+      } catch {
+        case e: ClassNotFoundException =>
+          throw new IllegalArgumentException(s"Codec [$codecName] is not " +
+            s"available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
+      }
+  }
+}
diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
index 6fddab5..133f6ea 100755
--- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -473,6 +473,24 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
   }
 
+  test("Scala API save with gzip compression codec by shortened name") {
+    // Create temp directory
+    TestUtils.deleteRecursively(new File(tempEmptyDir))
+    new File(tempEmptyDir).mkdirs()
+    val copyFilePath = tempEmptyDir + "cars-copy.csv"
+
+    val cars = sqlContext.csvFile(carsFile, parserLib = parserLib)
+    cars.save("com.databricks.spark.csv", SaveMode.Overwrite,
+      Map("path" -> copyFilePath, "header" -> "true", "codec" -> "gZiP"))
+    val carsCopyPartFile = new File(copyFilePath, "part-00000.gz")
+    // Check that the part file has a .gz extension
+    assert(carsCopyPartFile.exists())
+
+    val carsCopy = sqlContext.csvFile(copyFilePath + "/")
+
+    assert(carsCopy.count == cars.count)
+    assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
+  }
+
   test("DSL save with quoting") {
     // Create temp directory
diff --git a/src/test/scala/com/databricks/spark/csv/util/CompressionCodecsSuite.scala b/src/test/scala/com/databricks/spark/csv/util/CompressionCodecsSuite.scala
new file mode 100644
index 0000000..f9bdc79
--- /dev/null
+++ b/src/test/scala/com/databricks/spark/csv/util/CompressionCodecsSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.csv.util
+
+import org.apache.hadoop.io.compress._
+import org.scalatest.FunSuite
+
+class CompressionCodecsSuite extends FunSuite {
+
+  /**
+   * Note that the Lz4 codec was added in Hadoop 2.x, so some tests might fail with a
+   * class-not-found exception on lower Hadoop versions.
+   */
+  test("Get classes of compression codecs") {
+    assert(CompressionCodecs.getCodecClass(classOf[GzipCodec].getName) == classOf[GzipCodec])
+    assert(CompressionCodecs.getCodecClass(classOf[SnappyCodec].getName) == classOf[SnappyCodec])
+    assert(CompressionCodecs.getCodecClass(classOf[Lz4Codec].getName) == classOf[Lz4Codec])
+    assert(CompressionCodecs.getCodecClass(classOf[BZip2Codec].getName) == classOf[BZip2Codec])
+  }
+
+  test("Get classes of compression codecs with short names") {
+    assert(CompressionCodecs.getCodecClass("GzIp") == classOf[GzipCodec])
+    assert(CompressionCodecs.getCodecClass("Snappy") == classOf[SnappyCodec])
+    assert(CompressionCodecs.getCodecClass("lz4") == classOf[Lz4Codec])
+    assert(CompressionCodecs.getCodecClass("bZip2") == classOf[BZip2Codec])
+  }
+}
diff --git a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
index d713649..41f3ba8 100644
--- a/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
+++ b/src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala
@@ -1,3 +1,18 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package com.databricks.spark.csv.util
 
 import org.apache.spark.sql.types._
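
Reviewer note: below is a minimal usage sketch of the new `codec` option, modeled on the Scala API test above. It is illustrative only and not part of the patch; the input and output paths are placeholders, and `sqlContext` is assumed to be an existing `SQLContext`.

```scala
import org.apache.spark.sql.SaveMode
import com.databricks.spark.csv._  // adds the csvFile method to SQLContext

// Load a CSV file, then save it back gzip-compressed. With this patch the codec
// may be given as a shortened name ("bzip2", "gzip", "lz4", "snappy", matched
// case-insensitively) instead of a fully qualified class name.
val cars = sqlContext.csvFile("cars.csv")
cars.save("com.databricks.spark.csv", SaveMode.Overwrite,
  Map("path" -> "cars-copy.csv", "header" -> "true", "codec" -> "gzip"))

// The fully qualified class name still works exactly as before:
cars.save("com.databricks.spark.csv", SaveMode.Overwrite,
  Map("path" -> "cars-copy2.csv", "header" -> "true",
    "codec" -> "org.apache.hadoop.io.compress.GzipCodec"))
```

Either spelling is resolved through `CompressionCodecs.getCodecClass`; an unknown name raises an `IllegalArgumentException` that lists the known short names.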