
Commit 4b999fa

Add the support for compression in TEXT datasource
1 parent 1b39faf commit 4b999fa

File tree: 5 files changed, +57 −27 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}
 
 import org.apache.spark.util.Utils
@@ -44,4 +46,16 @@ private[datasources] object CompressionCodecs {
         s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
     }
   }
+
+  /**
+   * Set compression configurations to Hadoop `Configuration`.
+   * `codec` should be a full class path
+   */
+  def setCodecConfiguration(conf: Configuration, codec: String): Unit = {
+    conf.set("mapreduce.output.fileoutputformat.compress", "true")
+    conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
+    conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
+    conf.set("mapreduce.map.output.compress", "true")
+    conf.set("mapreduce.map.output.compress.codec", codec)
+  }
 }
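
The new helper centralizes the five Hadoop compression properties that each data source previously set by hand. Below is a minimal sketch of how a write path combines it with the existing getCodecClassName; the wrapper function configureCompression is hypothetical, only the two CompressionCodecs calls come from this commit, and since CompressionCodecs is private[datasources], real callers live under that package (as the CSV, JSON and TEXT sources below do).

import org.apache.hadoop.mapreduce.Job
import org.apache.spark.sql.execution.datasources.CompressionCodecs

// Hypothetical wiring inside a relation's prepareJobForWrite: resolve a short
// codec name such as "gzip" to its full class name, then push it into the
// Hadoop job configuration through the new helper.
def configureCompression(job: Job, codecShortName: Option[String]): Unit = {
  val conf = job.getConfiguration
  codecShortName
    .map(CompressionCodecs.getCodecClassName)  // e.g. "gzip" -> "org.apache.hadoop.io.compress.GzipCodec"
    .foreach(codec => CompressionCodecs.setCodecConfiguration(conf, codec))
}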

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala

Lines changed: 16 additions & 20 deletions
@@ -24,7 +24,6 @@ import scala.util.control.NonFatal
 import com.google.common.base.Objects
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
-import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.RecordWriter
@@ -34,6 +33,7 @@ import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.CompressionCodecs
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 
@@ -50,16 +50,16 @@ private[sql] class CSVRelation(
     case None => inferSchema(paths)
   }
 
-  private val params = new CSVOptions(parameters)
+  private val options = new CSVOptions(parameters)
 
   @transient
   private var cachedRDD: Option[RDD[String]] = None
 
   private def readText(location: String): RDD[String] = {
-    if (Charset.forName(params.charset) == Charset.forName("UTF-8")) {
+    if (Charset.forName(options.charset) == Charset.forName("UTF-8")) {
       sqlContext.sparkContext.textFile(location)
     } else {
-      val charset = params.charset
+      val charset = options.charset
       sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
         .mapPartitions { _.map { pair =>
           new String(pair._2.getBytes, 0, pair._2.getLength, charset)
@@ -81,8 +81,8 @@ private[sql] class CSVRelation(
   private def tokenRdd(header: Array[String], inputPaths: Array[String]): RDD[Array[String]] = {
     val rdd = baseRdd(inputPaths)
     // Make sure firstLine is materialized before sending to executors
-    val firstLine = if (params.headerFlag) findFirstLine(rdd) else null
-    CSVRelation.univocityTokenizer(rdd, header, firstLine, params)
+    val firstLine = if (options.headerFlag) findFirstLine(rdd) else null
+    CSVRelation.univocityTokenizer(rdd, header, firstLine, options)
   }
 
   /**
@@ -96,20 +96,16 @@ private[sql] class CSVRelation(
     val pathsString = inputs.map(_.getPath.toUri.toString)
     val header = schema.fields.map(_.name)
     val tokenizedRdd = tokenRdd(header, pathsString)
-    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, params)
+    CSVRelation.parseCsv(tokenizedRdd, schema, requiredColumns, inputs, sqlContext, options)
   }
 
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
     val conf = job.getConfiguration
-    params.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+    options.compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
-    new CSVOutputWriterFactory(params)
+    new CSVOutputWriterFactory(options)
   }
 
   override def hashCode(): Int = Objects.hashCode(paths.toSet, dataSchema, schema, partitionColumns)
@@ -129,17 +125,17 @@ private[sql] class CSVRelation(
   private def inferSchema(paths: Array[String]): StructType = {
     val rdd = baseRdd(paths)
     val firstLine = findFirstLine(rdd)
-    val firstRow = new LineCsvReader(params).parseLine(firstLine)
+    val firstRow = new LineCsvReader(options).parseLine(firstLine)
 
-    val header = if (params.headerFlag) {
+    val header = if (options.headerFlag) {
       firstRow
     } else {
       firstRow.zipWithIndex.map { case (value, index) => s"C$index" }
     }
 
     val parsedRdd = tokenRdd(header, paths)
-    if (params.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, params.nullValue)
+    if (options.inferSchemaFlag) {
+      CSVInferSchema.infer(parsedRdd, header, options.nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
@@ -153,8 +149,8 @@ private[sql] class CSVRelation(
    * Returns the first line of the first non-empty file in path
    */
   private def findFirstLine(rdd: RDD[String]): String = {
-    if (params.isCommentSet) {
-      val comment = params.comment.toString
+    if (options.isCommentSet) {
+      val comment = options.comment.toString
       rdd.filter { line =>
         line.trim.nonEmpty && !line.startsWith(comment)
       }.first()
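
The CSV change is a pure refactor plus a rename from params to options: the five conf.set calls collapse into one CompressionCodecs.setCodecConfiguration call. A quick sanity sketch of the properties that call is expected to leave in the configuration (gzip codec assumed; CompressionCodecs is private[datasources], so this snippet would have to live in a test under that package):

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.execution.datasources.CompressionCodecs

// The refactored CSV, JSON and TEXT write paths all rely on this single call
// producing the same five Hadoop properties as the inlined code it replaces.
val conf = new Configuration()
CompressionCodecs.setCodecConfiguration(conf, "org.apache.hadoop.io.compress.GzipCodec")
assert(conf.get("mapreduce.output.fileoutputformat.compress") == "true")
assert(conf.get("mapreduce.output.fileoutputformat.compress.codec") ==
  "org.apache.hadoop.io.compress.GzipCodec")
assert(conf.get("mapreduce.map.output.compress") == "true")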

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala

Lines changed: 1 addition & 5 deletions
@@ -165,11 +165,7 @@ private[sql] class JSONRelation(
   override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = {
     val conf = job.getConfiguration
     options.compressionCodec.foreach { codec =>
-      conf.set("mapreduce.output.fileoutputformat.compress", "true")
-      conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
-      conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
-      conf.set("mapreduce.map.output.compress", "true")
-      conf.set("mapreduce.map.output.compress.codec", codec)
+      CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
     new BucketedOutputWriterFactory {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala

Lines changed: 11 additions & 2 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, PartitionSpec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -48,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation = {
     dataSchema.foreach(verifySchema)
-    new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext)
+    new TextRelation(None, dataSchema, partitionColumns, paths, parameters)(sqlContext)
   }
 
   override def shortName(): String = "text"
@@ -114,6 +114,15 @@ private[sql] class TextRelation(
 
   /** Write path. */
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+    val conf = job.getConfiguration
+    val compressionCodec = {
+      val name = parameters.get("compression").orElse(parameters.get("codec"))
+      name.map(CompressionCodecs.getCodecClassName)
+    }
+    compressionCodec.foreach { codec =>
+      CompressionCodecs.setCodecConfiguration(conf, codec)
+    }
+
     new OutputWriterFactory {
       override def newInstance(
           path: String,
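
With this change the text writer accepts a "compression" option (or "codec"), mirroring CSV and JSON. An end-to-end usage sketch using the SQLContext API that appears elsewhere in this commit; the app name and paths are made up:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Write a single-column DataFrame of strings as gzip-compressed text; short
// codec names are resolved through CompressionCodecs.getCodecClassName.
object TextCompressionExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("text-compression-example"))
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.read.text("/path/to/input.txt")
    df.write
      .option("compression", "gzip")    // "codec" also works; the name is case-insensitive per the test below
      .text("/path/to/output-gzipped")  // part files end in ".gz"

    sc.stop()
  }
}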

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala

Lines changed: 15 additions & 0 deletions
@@ -57,6 +57,21 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
+    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
+
+    val tempFile = Utils.createTempDir()
+    tempFile.delete()
+    df.write
+      .option("compression", "gZiP")
+      .text(tempFile.getCanonicalPath)
+    val compressedFiles = tempFile.listFiles()
+    assert(compressedFiles.exists(_.getName.endsWith(".gz")))
+    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
+
+    Utils.deleteRecursively(tempFile)
+  }
+
   private def testFile: String = {
     Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
   }
