Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.Charset

import org.apache.hadoop.io.compress._

import org.apache.spark.Logging
import org.apache.spark.util.Utils

private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging {

Expand All @@ -35,7 +38,7 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]

private def getBool(paramName: String, default: Boolean = false): Boolean = {
val param = parameters.getOrElse(paramName, default.toString)
if (param.toLowerCase() == "true") {
if (param.toLowerCase == "true") {
true
} else if (param.toLowerCase == "false") {
false
Expand Down Expand Up @@ -73,6 +76,11 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]

val nullValue = parameters.getOrElse("nullValue", "")

// Class name of the codec to compress output with, if one was requested.
// The "compression" option is consulted first; "codec" is only used when it is absent.
// The configured name is resolved through CSVCompressionCodecs.getCodecClassName.
val compressionCodec: Option[String] =
  parameters.get("compression")
    .orElse(parameters.get("codec"))
    .map(CSVCompressionCodecs.getCodecClassName)

val maxColumns = 20480

val maxCharsPerColumn = 100000
Expand All @@ -85,7 +93,6 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String]
}

private[csv] object ParseModes {

val PERMISSIVE_MODE = "PERMISSIVE"
val DROP_MALFORMED_MODE = "DROPMALFORMED"
val FAIL_FAST_MODE = "FAILFAST"
Expand All @@ -107,3 +114,28 @@ private[csv] object ParseModes {
true // We default to permissive if the mode string is not valid
}
}

/**
 * Resolves user-supplied compression codec names for the CSV data source.
 */
private[csv] object CSVCompressionCodecs {
  /** Short aliases (stored in lower case) mapped to the codec class names they stand for. */
  private val shortCompressionCodecNames = Map(
    "bzip2" -> classOf[BZip2Codec].getName,
    "gzip" -> classOf[GzipCodec].getName,
    "lz4" -> classOf[Lz4Codec].getName,
    "snappy" -> classOf[SnappyCodec].getName)

  /**
   * Return the fully qualified class name for the given codec name.
   *
   * Short aliases are matched case-insensitively. Any other value is treated as
   * a class name and returned unchanged once it has been verified to be loadable.
   *
   * @param name a short alias (for example "gzip") or a codec class name
   * @return the codec class name
   * @throws IllegalArgumentException if the resolved class cannot be found
   */
  def getCodecClassName(name: String): String = {
    // Lower-case with Locale.ROOT so alias lookup does not depend on the JVM default
    // locale (e.g. under a Turkish locale "GZIP".toLowerCase contains a dotless 'ı').
    val codecName =
      shortCompressionCodecNames.getOrElse(name.toLowerCase(java.util.Locale.ROOT), name)
    try {
      // Validate the codec name by checking that the class can be loaded.
      Utils.classForName(codecName)
      codecName
    } catch {
      case e: ClassNotFoundException =>
        val known = shortCompressionCodecNames.keys.mkString(", ")
        // Keep the original exception as the cause so the failed lookup stays visible.
        throw new IllegalArgumentException(
          s"Codec [$codecName] is not available. Known codecs are $known.", e)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.util.control.NonFatal
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.hadoop.mapreduce.RecordWriter
Expand Down Expand Up @@ -99,6 +100,15 @@ private[csv] class CSVRelation(
}

/**
 * Configure the write job and return the factory used to create CSV output writers.
 *
 * When `params.compressionCodec` is defined, output compression is switched on in the
 * job configuration with BLOCK compression type and the given codec; the map-output
 * compression keys are set to the same codec. With no codec the configuration is
 * left as it was.
 */
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
  val hadoopConf = job.getConfiguration
  params.compressionCodec match {
    case Some(codecClassName) =>
      val compressionSettings = Seq(
        "mapreduce.output.fileoutputformat.compress" -> "true",
        "mapreduce.output.fileoutputformat.compress.type" -> CompressionType.BLOCK.toString,
        "mapreduce.output.fileoutputformat.compress.codec" -> codecClassName,
        "mapreduce.map.output.compress" -> "true",
        "mapreduce.map.output.compress.codec" -> codecClassName)
      compressionSettings.foreach { case (key, value) => hadoopConf.set(key, value) }
    case None =>
      // No codec requested: nothing to configure.
  }

  new CSVOutputWriterFactory(params)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null"))
assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
}

// Round trip: write the cars data as compressed CSV, check that a ".gz" file was produced,
// then read the directory back and verify the contents.
test("save csv with compression codec option") {
  withTempDir { tmpDir =>
    val outputPath = new File(tmpDir, "csv").getCanonicalPath

    val original = sqlContext.read
      .format("csv")
      .option("header", "true")
      .load(testFile(carsFile))

    // The codec name is deliberately mixed-case.
    val singlePartition = original.coalesce(1)
    singlePartition.write
      .format("csv")
      .option("header", "true")
      .option("compression", "gZiP")
      .save(outputPath)

    val writtenFiles = new File(outputPath).listFiles()
    assert(writtenFiles.exists(file => file.getName.endsWith(".gz")))

    val reloaded = sqlContext.read
      .format("csv")
      .option("header", "true")
      .load(outputPath)

    verifyCars(reloaded, withHeader = true)
  }
}
}