Skip to content

Commit 25e9250

Browse files
committed
Add a common utility to map short codec names to fully-qualified codec names
1 parent f77dc4e commit 25e9250

File tree

6 files changed

+106
-57
lines changed

6 files changed

+106
-57
lines changed

core/src/main/scala/org/apache/spark/io/CompressionCodec.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
2525

2626
import org.apache.spark.SparkConf
2727
import org.apache.spark.annotation.DeveloperApi
28-
import org.apache.spark.util.Utils
28+
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}
2929

3030
/**
3131
* :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
5353
|| codec.isInstanceOf[LZ4CompressionCodec])
5454
}
5555

56-
private val shortCompressionCodecNames = Map(
57-
"lz4" -> classOf[LZ4CompressionCodec].getName,
58-
"lzf" -> classOf[LZFCompressionCodec].getName,
59-
"snappy" -> classOf[SnappyCompressionCodec].getName)
56+
/** Maps the short versions of compression codec names to fully-qualified class names. */
57+
private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
58+
override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
59+
override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
60+
override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
61+
}
62+
63+
private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap
6064

6165
def getCodecName(conf: SparkConf): String = {
6266
conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@ private[spark] object CompressionCodec {
6771
}
6872

6973
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
70-
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
74+
val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
7175
val codec = try {
7276
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
7377
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,18 +88,18 @@ private[spark] object CompressionCodec {
8488
* If it is already a short name, just return it.
8589
*/
8690
def getShortName(codecName: String): String = {
87-
if (shortCompressionCodecNames.contains(codecName)) {
91+
if (shortCompressionCodecMap.contains(codecName)) {
8892
codecName
8993
} else {
90-
shortCompressionCodecNames
94+
shortCompressionCodecMap
9195
.collectFirst { case (k, v) if v == codecName => k }
9296
.getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
9397
}
9498
}
9599

96100
val FALLBACK_COMPRESSION_CODEC = "snappy"
97101
val DEFAULT_COMPRESSION_CODEC = "lz4"
98-
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
102+
val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
99103
}
100104

101105
/**

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,51 @@ private[spark] object CallSite {
6060
val empty = CallSite("", "")
6161
}
6262

63+
/**
 * A utility class to map short compression codec names to qualified ones.
 *
 * Each supported short name has an accessor at the bottom of this class that returns `None`
 * by default. Derived classes override the accessors of the codecs they support so that they
 * return the corresponding qualified codec name.
 */
private[spark] class ShortCompressionCodecNameMapper {

  /**
   * Looks up the qualified codec name registered for a short codec name.
   *
   * The lookup is case-insensitive. The name is lower-cased with `Locale.ROOT` rather than the
   * JVM default locale; otherwise, e.g. under a Turkish locale, "GZIP" would be lower-cased to
   * "gzıp" (dotless i) and fail to match "gzip".
   *
   * @param codecName a short codec name in any letter case
   * @return the qualified name returned by the matching accessor, or `None` when the name is
   *         not one of the known short names or its accessor has not been overridden
   */
  def get(codecName: String): Option[String] =
    codecName.toLowerCase(java.util.Locale.ROOT) match {
      case "none" => none
      case "uncompressed" => uncompressed
      case "bzip2" => bzip2
      case "deflate" => deflate
      case "gzip" => gzip
      case "lzo" => lzo
      case "lz4" => lz4
      case "lzf" => lzf
      case "snappy" => snappy
      case _ => None
    }

  /**
   * Returns the supported short names (lower case) mapped to their qualified codec names.
   *
   * Only short names whose accessor returns a defined value are included.
   */
  def getAsMap: Map[String, String] = {
    Seq(
      "none" -> none,
      "uncompressed" -> uncompressed,
      "bzip2" -> bzip2,
      "deflate" -> deflate,
      "gzip" -> gzip,
      "lzo" -> lzo,
      "lz4" -> lz4,
      "lzf" -> lzf,
      "snappy" -> snappy
    ).collect {
      // Keep only the codecs a derived class has actually provided a name for.
      case (shortCodecName, Some(codecName)) => shortCodecName -> codecName
    }.toMap
  }

  // To support short codec names, derived classes need to override the methods below so that
  // they return the corresponding qualified codec names.
  def none: Option[String] = None
  def uncompressed: Option[String] = None
  def bzip2: Option[String] = None
  def deflate: Option[String] = None
  def gzip: Option[String] = None
  def lzo: Option[String] = None
  def lz4: Option[String] = None
  def lzf: Option[String] = None
  def snappy: Option[String] = None
}
107+
63108
/**
64109
* Various utility methods used by Spark.
65110
*/

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
/*
2-
* Licensed to the Apache Software Foundation (ASF) under one or more
3-
* contributor license agreements. See the NOTICE file distributed with
4-
* this work for additional information regarding copyright ownership.
5-
* The ASF licenses this file to You under the Apache License, Version 2.0
6-
* (the "License"); you may not use this file except in compliance with
7-
* the License. You may obtain a copy of the License at
8-
*
9-
* http://www.apache.org/licenses/LICENSE-2.0
10-
*
11-
* Unless required by applicable law or agreed to in writing, software
12-
* distributed under the License is distributed on an "AS IS" BASIS,
13-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14-
* See the License for the specific language governing permissions and
15-
* limitations under the License.
16-
*/
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
1717

1818
package org.apache.spark.sql
1919

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CompressionCodecs.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,37 @@ package org.apache.spark.sql.execution.datasources
1919

2020
import org.apache.hadoop.conf.Configuration
2121
import org.apache.hadoop.io.SequenceFile.CompressionType
22-
import org.apache.hadoop.io.compress.{BZip2Codec, GzipCodec, Lz4Codec, SnappyCodec}
22+
import org.apache.hadoop.io.compress.{BZip2Codec, DeflateCodec, GzipCodec, Lz4Codec, SnappyCodec}
23+
import org.apache.spark.util.ShortCompressionCodecNameMapper
2324

2425
import org.apache.spark.util.Utils
2526

2627
private[datasources] object CompressionCodecs {
27-
private val shortCompressionCodecNames = Map(
28-
"bzip2" -> classOf[BZip2Codec].getName,
29-
"gzip" -> classOf[GzipCodec].getName,
30-
"lz4" -> classOf[Lz4Codec].getName,
31-
"snappy" -> classOf[SnappyCodec].getName)
28+
29+
/** Maps the short versions of compression codec names to fully-qualified class names. */
30+
private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
31+
override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
32+
override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
33+
override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
34+
override def lz4: Option[String] = Some(classOf[Lz4Codec].getCanonicalName)
35+
override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
36+
}
3237

3338
/**
3439
* Return the full version of the given codec class.
3540
* If it is already a class name, just return it.
3641
*/
3742
def getCodecClassName(name: String): String = {
38-
val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name)
43+
val codecName = hadoopShortCodecNameMapper.get(name).getOrElse(name)
3944
try {
4045
// Validate the codec name
4146
Utils.classForName(codecName)
4247
codecName
4348
} catch {
4449
case e: ClassNotFoundException =>
4550
throw new IllegalArgumentException(s"Codec [$codecName] " +
46-
s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.")
51+
s"is not available. Known codecs are " +
52+
s"${hadoopShortCodecNameMapper.getAsMap.keys.mkString(", ")}.")
4753
}
4854
}
4955

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
5050
import org.apache.spark.sql.internal.SQLConf
5151
import org.apache.spark.sql.sources._
5252
import org.apache.spark.sql.types.{DataType, StructType}
53-
import org.apache.spark.util.{SerializableConfiguration, Utils}
53+
import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}
5454

5555
private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
5656

@@ -284,10 +284,8 @@ private[sql] class ParquetRelation(
284284
conf.set(
285285
ParquetOutputFormat.COMPRESSION,
286286
ParquetRelation
287-
.shortParquetCompressionCodecNames
288-
.getOrElse(
289-
sqlContext.conf.parquetCompressionCodec.toUpperCase,
290-
CompressionCodecName.UNCOMPRESSED).name())
287+
.parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
288+
.getOrElse(CompressionCodecName.UNCOMPRESSED.name()))
291289

292290
new BucketedOutputWriterFactory {
293291
override def newInstance(
@@ -903,11 +901,12 @@ private[sql] object ParquetRelation extends Logging {
903901
}
904902
}
905903

906-
// The parquet compression short names
907-
val shortParquetCompressionCodecNames = Map(
908-
"NONE" -> CompressionCodecName.UNCOMPRESSED,
909-
"UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
910-
"SNAPPY" -> CompressionCodecName.SNAPPY,
911-
"GZIP" -> CompressionCodecName.GZIP,
912-
"LZO" -> CompressionCodecName.LZO)
904+
/** Maps the short versions of compression codec names to qualified compression names. */
905+
val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
906+
override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
907+
override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
908+
override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
909+
override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
910+
override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
911+
}
913912
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution.datasources.text
1919

20-
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
20+
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
2121
import org.apache.spark.sql.test.SharedSQLContext
2222
import org.apache.spark.sql.types.{StringType, StructType}
2323
import org.apache.spark.util.Utils
@@ -58,18 +58,13 @@ class TextSuite extends QueryTest with SharedSQLContext {
5858
}
5959

6060
test("SPARK-13503 Support to specify the option for compression codec for TEXT") {
61-
val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")
62-
63-
val tempFile = Utils.createTempDir()
64-
tempFile.delete()
65-
df.write
66-
.option("compression", "gZiP")
67-
.text(tempFile.getCanonicalPath)
68-
val compressedFiles = tempFile.listFiles()
69-
assert(compressedFiles.exists(_.getName.endsWith(".gz")))
70-
verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
71-
72-
Utils.deleteRecursively(tempFile)
61+
Seq("bzip2", "deflate", "gzip").map { codecName =>
62+
val tempDir = Utils.createTempDir()
63+
val tempDirPath = tempDir.getAbsolutePath()
64+
val df = sqlContext.read.text(testFile)
65+
df.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
66+
verifyFrame(sqlContext.read.text(tempDirPath))
67+
}
7368
}
7469

7570
private def testFile: String = {

0 commit comments

Comments
 (0)