From d3cddb228f20b738b76c0fd74f6891dc9beaedb8 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 2 Feb 2016 12:52:20 +0900
Subject: [PATCH 1/7] Validate ascii compatible encodings

---
 .../datasources/csv/CSVOptions.scala          | 20 ++++++++++++++++---
 .../execution/datasources/csv/CSVSuite.scala  | 12 +++++++++++
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 709daccbbef58..090759c4a7c92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.nio.charset.Charset
+import java.nio.charset.{Charset, UnsupportedCharsetException}
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.execution.datasources.CompressionCodecs
@@ -47,11 +47,25 @@ private[sql] class CSVOptions(
     }
   }
 
+  private def checkedCharset(charsetName: String): String = {
+    val charset = Charset.forName(charsetName)
+    val lineSeq = "\n"
+    // Currently this datasource does not support non-ascii compatible encodings. See SPARK-13108
+    val isASCIICompatible =
+      java.util.Arrays.equals(
+        lineSeq.getBytes(Charset.forName("UTF-8")), lineSeq.getBytes(charset))
+    if (!isASCIICompatible) {
+      throw new UnsupportedCharsetException(charsetName)
+    } else {
+      charsetName
+    }
+  }
+
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val charset = parameters.getOrElse("encoding",
-    parameters.getOrElse("charset", Charset.forName("UTF-8").name()))
+  val charset = checkedCharset( parameters.getOrElse("encoding",
+    parameters.getOrElse("charset", Charset.forName("UTF-8").name())))
 
   val quote = getChar("quote", '\"')
   val escape = getChar("escape", '\\')
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index a79566b1f3658..83c94b0a39e19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -122,6 +122,18 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(exception.getMessage.contains("1-9588-osi"))
   }
 
+  test("non-ascii compatible encoding name") {
+    val exception = intercept[UnsupportedCharsetException] {
+      sqlContext
+        .read
+        .format("csv")
+        .option("charset", "UTF-16")
+        .load(testFile(carsFile8859))
+    }
+
+    assert(exception.getMessage.contains("UTF-16"))
+  }
+
   test("test different encoding") {
     // scalastyle:off
     sqlContext.sql(
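The validation in checkedCharset above boils down to a single byte-level comparison: a charset is treated as ASCII-compatible only if it encodes "\n" to exactly the bytes UTF-8 produces. A minimal, self-contained sketch of that check (object and method names here are illustrative, not part of the patch):

    import java.nio.charset.Charset
    import java.util.Arrays

    object AsciiCompatibilityCheck {
      // Same idea as checkedCharset: compare how the charset and UTF-8 encode "\n".
      def isAsciiCompatible(charsetName: String): Boolean = {
        val lineSep = "\n"
        Arrays.equals(
          lineSep.getBytes(Charset.forName("UTF-8")),
          lineSep.getBytes(Charset.forName(charsetName)))
      }

      def main(args: Array[String]): Unit = {
        println(isAsciiCompatible("ISO-8859-1")) // true: "\n" is the single byte 0x0A
        println(isAsciiCompatible("UTF-16LE"))   // false: "\n" becomes 0x0A 0x00
        println(isAsciiCompatible("UTF-16"))     // false: a BOM is prepended as well
      }
    }

Single-byte encodings such as ISO-8859-1 pass this check, while the UTF-16/UTF-32 family fails; that failing set is exactly what the later patches in this series go on to support.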
From 9f3735ca31f323fa460fd2f386fcef99725d890b Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 2 Feb 2016 12:53:33 +0900
Subject: [PATCH 2/7] Remove an extra leading whitespace

---
 .../apache/spark/sql/execution/datasources/csv/CSVOptions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 090759c4a7c92..e5f7f1554e192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -64,7 +64,7 @@ private[sql] class CSVOptions(
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val charset = checkedCharset( parameters.getOrElse("encoding",
+  val charset = checkedCharset(parameters.getOrElse("encoding",
     parameters.getOrElse("charset", Charset.forName("UTF-8").name())))
 
   val quote = getChar("quote", '\"')

From 34af8d371b4b8b88f8f441150185ad0ea618a9cc Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 4 Feb 2016 15:53:56 +0900
Subject: [PATCH 3/7] Add the support for non-ascii compatible encodings

---
 .../datasources/csv/CSVOptions.scala          | 19 +---
 .../datasources/csv/CSVRelation.scala         | 88 ++++++++++++++++--
 sql/core/src/test/resources/cars_utf-16.csv   | Bin 0 -> 268 bytes
 .../execution/datasources/csv/CSVSuite.scala  | 18 ++--
 4 files changed, 93 insertions(+), 32 deletions(-)
 create mode 100644 sql/core/src/test/resources/cars_utf-16.csv

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index e5f7f1554e192..f8e0f2bb8104e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -46,26 +46,11 @@ private[sql] class CSVOptions(
       throw new Exception(s"$paramName flag can be true or false")
     }
   }
-
-  private def checkedCharset(charsetName: String): String = {
-    val charset = Charset.forName(charsetName)
-    val lineSeq = "\n"
-    // Currently this datasource does not support non-ascii compatible encodings. See SPARK-13108
-    val isASCIICompatible =
-      java.util.Arrays.equals(
-        lineSeq.getBytes(Charset.forName("UTF-8")), lineSeq.getBytes(charset))
-    if (!isASCIICompatible) {
-      throw new UnsupportedCharsetException(charsetName)
-    } else {
-      charsetName
-    }
-  }
-
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
-  val charset = checkedCharset(parameters.getOrElse("encoding",
-    parameters.getOrElse("charset", Charset.forName("UTF-8").name())))
+  val charset = parameters.getOrElse("encoding",
+    parameters.getOrElse("charset", Charset.forName("UTF-8").name()))
 
   val quote = getChar("quote", '\"')
   val escape = getChar("escape", '\\')
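For reference, the charset lookup that remains after this change is a plain nested getOrElse: "encoding" takes precedence over "charset", with UTF-8 as the fallback. A tiny sketch of that resolution order, outside any Spark class:

    import java.nio.charset.Charset

    // "encoding" wins over "charset"; UTF-8 is the default when neither is set.
    val parameters = Map("charset" -> "ISO-8859-1")
    val charset = parameters.getOrElse("encoding",
      parameters.getOrElse("charset", Charset.forName("UTF-8").name()))
    // charset == "ISO-8859-1"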
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index dc449fea956f8..e97f11ff20e75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -22,12 +22,12 @@ import java.nio.charset.Charset
 import scala.util.control.NonFatal
 
 import com.google.common.base.Objects
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
 import org.apache.hadoop.io.SequenceFile.CompressionType
-import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.RecordWriter
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.{LineRecordReader, TextInputFormat}
 import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.Logging
@@ -60,7 +60,10 @@ private[csv] class CSVRelation(
       sqlContext.sparkContext.textFile(location)
     } else {
       val charset = params.charset
-      sqlContext.sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](location)
+      val conf = sqlContext.sparkContext.hadoopConfiguration
+      conf.set(EncodingTextInputFormat.ENCODING_KEY, charset)
+      sqlContext.sparkContext
+        .newAPIHadoopFile[LongWritable, Text, EncodingTextInputFormat](location)
         .mapPartitions { _.map { pair =>
           new String(pair._2.getBytes, 0, pair._2.getLength, charset)
         }
@@ -249,16 +252,89 @@ object CSVRelation extends Logging {
   }
 }
 
+/**
+ * Because `TextInputFormat` in Hadoop does not support non-ascii compatible encodings,
+ * we need another `InputFormat` to handle the encodings. See SPARK-13108.
+ */
+private[csv] class EncodingTextInputFormat extends TextInputFormat {
+  override def createRecordReader(
+      split: InputSplit,
+      context: TaskAttemptContext): RecordReader[LongWritable, Text] = {
+    val conf: Configuration = {
+      // Use reflection to get the Configuration. This is necessary because TaskAttemptContext is
+      // a class in Hadoop 1.x and an interface in Hadoop 2.x.
+      val method = context.getClass.getMethod("getConfiguration")
+      method.invoke(context).asInstanceOf[Configuration]
+    }
+    val charset = Charset.forName(conf.get(EncodingTextInputFormat.ENCODING_KEY, "UTF-8"))
+    val charsetName = charset.name
+    val safeRecordDelimiterBytes = {
+      val delimiter = "\n"
+      val recordDelimiterBytes = delimiter.getBytes(charset)
+      EncodingTextInputFormat.stripBOM(charsetName, recordDelimiterBytes)
+    }
+    new LineRecordReader(safeRecordDelimiterBytes) {
+      var isFirst = true
+      override def getCurrentValue: Text = {
+        val value = super.getCurrentValue
+        if (isFirst) {
+          isFirst = false
+          val safeBytes = EncodingTextInputFormat.stripBOM(charsetName, value.getBytes)
+          new Text(safeBytes)
+        } else {
+          value
+        }
+      }
+    }
+  }
+}
+
+private[csv] object EncodingTextInputFormat {
+  // configuration key for encoding type
+  val ENCODING_KEY = "encodinginputformat.encoding"
+
+  def stripBOM(charsetName: String, bytes: Array[Byte]): Array[Byte] = {
+    charsetName match {
+      case "UTF-8"
+        if bytes(0) == 0xEF.toByte &&
+          bytes(1) == 0xBB.toByte &&
+          bytes(2) == 0xBF.toByte =>
+        bytes.slice(3, bytes.length)
+      case "UTF-16" | "UTF-16BE"
+        if bytes(0) == 0xFE.toByte &&
+          bytes(1) == 0xFF.toByte =>
+        bytes.slice(2, bytes.length)
+      case "UTF-16LE"
+        if bytes(0) == 0xFF.toByte &&
+          bytes(1) == 0xFE.toByte =>
+        bytes.slice(2, bytes.length)
+      case "UTF-32" | "UTF-32BE"
+        if bytes(0) == 0x00.toByte &&
+          bytes(1) == 0x00.toByte &&
+          bytes(2) == 0xFE.toByte &&
+          bytes(3) == 0xFF.toByte =>
+        bytes.slice(4, bytes.length)
+      case "UTF-32LE"
+        if bytes(0) == 0xFF.toByte &&
+          bytes(1) == 0xFE.toByte &&
+          bytes(2) == 0x00.toByte &&
+          bytes(3) == 0x00.toByte =>
+        bytes.slice(4, bytes.length)
+      case _ => bytes
+    }
+  }
+}
+
 private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
   override def newInstance(
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    new CsvOutputWriter(path, dataSchema, context, params)
+    new CSVOutputWriter(path, dataSchema, context, params)
   }
 }
 
-private[sql] class CsvOutputWriter(
+private[sql] class CSVOutputWriter(
     path: String,
     dataSchema: StructType,
     context: TaskAttemptContext,
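The reason EncodingTextInputFormat encodes the record delimiter with the file's charset, and then strips a BOM both from that delimiter and from the first record of a split, becomes clear by printing what "\n" actually looks like in a few encodings. A small standalone sketch (the helper name is illustrative, not part of the patch):

    import java.nio.charset.Charset

    object DelimiterBytesDemo {
      private def hex(bytes: Array[Byte]): String =
        bytes.map(b => f"${b & 0xff}%02X").mkString(" ")

      def main(args: Array[String]): Unit = {
        // LineRecordReader searches for these exact bytes, so "\n" must be
        // encoded with the file's charset rather than assumed to be 0x0A.
        println(hex("\n".getBytes(Charset.forName("UTF-8"))))    // 0A
        println(hex("\n".getBytes(Charset.forName("UTF-16LE")))) // 0A 00
        println(hex("\n".getBytes(Charset.forName("UTF-16"))))   // FE FF 00 0A (BOM prepended)
      }
    }

That prepended BOM is why the delimiter bytes go through stripBOM before being handed to LineRecordReader, and why getCurrentValue strips it once more from the first value it returns.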
diff --git a/sql/core/src/test/resources/cars_utf-16.csv b/sql/core/src/test/resources/cars_utf-16.csv
new file mode 100644
index 0000000000000000000000000000000000000000..a94ed8cb14be5833e382ff20358e65d9b080bbc8
GIT binary patch
literal 268
zcmZ9Hy$ZrW5QL|`r;yZF7$u0cXgV9kKI$1Ie~=tXUS9n!XC(`_Gdr`h$@`;GPKA0|
zHS`J=P^T-X24BDp<

From: hyukjinkwon
Date: Thu, 4 Feb 2016 15:57:42 +0900
Subject: [PATCH 4/7] Remove unused import and add a newline

---
 .../spark/sql/execution/datasources/csv/CSVOptions.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index f8e0f2bb8104e..709daccbbef58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.nio.charset.{Charset, UnsupportedCharsetException}
+import java.nio.charset.Charset
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.execution.datasources.CompressionCodecs
@@ -46,6 +46,7 @@ private[sql] class CSVOptions(
       throw new Exception(s"$paramName flag can be true or false")
     }
   }
+
   val delimiter = CSVTypeCast.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
From 875e237b76a2f08ebd4fed2af82e63325965b608 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 25 Feb 2016 10:41:02 +0900
Subject: [PATCH 5/7] Shorten code for the same logic

---
 .../datasources/csv/CSVRelation.scala         | 40 ++++++++-----------
 1 file changed, 16 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index e97f11ff20e75..c3366117e5be7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -292,34 +292,26 @@ private[csv] class EncodingTextInputFormat extends TextInputFormat {
 private[csv] object EncodingTextInputFormat {
   // configuration key for encoding type
   val ENCODING_KEY = "encodinginputformat.encoding"
+  // BOM bytes for UTF-8, UTF-16 and UTF-32
+  private val utf8BOM = Array(0xEF.toByte, 0xBB.toByte, 0xBF.toByte)
+  private val utf16beBOM = Array(0xFE.toByte, 0xFF.toByte)
+  private val utf16leBOM = Array(0xFF.toByte, 0xFE.toByte)
+  private val utf32beBOM = Array(0x00.toByte, 0x00.toByte, 0xFE.toByte, 0xFF.toByte)
+  private val utf32leBOM = Array(0xFF.toByte, 0xFE.toByte, 0x00.toByte, 0x00.toByte)
 
   def stripBOM(charsetName: String, bytes: Array[Byte]): Array[Byte] = {
     charsetName match {
-      case "UTF-8"
-        if bytes(0) == 0xEF.toByte &&
-          bytes(1) == 0xBB.toByte &&
-          bytes(2) == 0xBF.toByte =>
-        bytes.slice(3, bytes.length)
-      case "UTF-16" | "UTF-16BE"
-        if bytes(0) == 0xFE.toByte &&
-          bytes(1) == 0xFF.toByte =>
+      case "UTF-8" if bytes.startsWith(utf8BOM) =>
+        bytes.slice(utf8BOM.length, bytes.length)
+      case "UTF-16" | "UTF-16BE" if bytes.startsWith(utf16beBOM) =>
+        bytes.slice(utf16beBOM.length, bytes.length)
+      case "UTF-16LE" if bytes.startsWith(utf16leBOM) =>
+        bytes.slice(utf16leBOM.length, bytes.length)
         bytes.slice(2, bytes.length)
-      case "UTF-16LE"
-        if bytes(0) == 0xFF.toByte &&
-          bytes(1) == 0xFE.toByte =>
-        bytes.slice(2, bytes.length)
-      case "UTF-32" | "UTF-32BE"
-        if bytes(0) == 0x00.toByte &&
-          bytes(1) == 0x00.toByte &&
-          bytes(2) == 0xFE.toByte &&
-          bytes(3) == 0xFF.toByte =>
-        bytes.slice(4, bytes.length)
-      case "UTF-32LE"
-        if bytes(0) == 0xFF.toByte &&
-          bytes(1) == 0xFE.toByte &&
-          bytes(2) == 0x00.toByte &&
-          bytes(3) == 0x00.toByte =>
-        bytes.slice(4, bytes.length)
+      case "UTF-32" | "UTF-32BE" if bytes.startsWith(utf32beBOM) =>
+        bytes.slice(utf32beBOM.length, bytes.length)
+      case "UTF-32LE" if bytes.startsWith(utf32leBOM) =>
+        bytes.slice(utf32leBOM.length, bytes.length)
       case _ => bytes
     }
   }
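A quick sanity sketch of what the refactored stripBOM is expected to do with the BOM constants above. This is not part of the patch and assumes it runs where the private[csv] object is visible (for example, inside the csv package or a test in it):

    val utf8Bom = Array(0xEF.toByte, 0xBB.toByte, 0xBF.toByte)
    val payload = "year,make,model".getBytes("UTF-8")

    // A BOM-prefixed record loses its BOM:
    val stripped = EncodingTextInputFormat.stripBOM("UTF-8", utf8Bom ++ payload)
    assert(java.util.Arrays.equals(stripped, payload))

    // Bytes without a BOM fall through to `case _` and come back unchanged:
    assert(java.util.Arrays.equals(
      EncodingTextInputFormat.stripBOM("UTF-8", payload), payload))

Matching on Array.startsWith against the named BOM constants keeps each case to one line; the duplicated slice left behind in the UTF-16LE case is removed by the next patch.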
From 4fa2ed30ff4a560d74b7f4858d30377c2cccfcd6 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 25 Feb 2016 10:46:38 +0900
Subject: [PATCH 6/7] Remove duplicated line

---
 .../apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index c3366117e5be7..aab0ebd673cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -307,7 +307,6 @@ private[csv] object EncodingTextInputFormat {
         bytes.slice(utf16beBOM.length, bytes.length)
       case "UTF-16LE" if bytes.startsWith(utf16leBOM) =>
         bytes.slice(utf16leBOM.length, bytes.length)
-        bytes.slice(2, bytes.length)
      case "UTF-32" | "UTF-32BE" if bytes.startsWith(utf32beBOM) =>
         bytes.slice(utf32beBOM.length, bytes.length)
       case "UTF-32LE" if bytes.startsWith(utf32leBOM) =>

From 264a1dc603164bd264e0c084608f31ffb8ad5f69 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Tue, 8 Mar 2016 15:36:10 +0900
Subject: [PATCH 7/7] Call EncodingTextInputFormat

---
 .../spark/sql/execution/datasources/csv/DefaultSource.scala | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index aff672281d640..9d7b21ae026c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -165,8 +165,10 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       sqlContext.sparkContext.textFile(location)
     } else {
       val charset = options.charset
+      val conf = sqlContext.sparkContext.hadoopConfiguration
+      conf.set(EncodingTextInputFormat.ENCODING_KEY, charset)
       sqlContext.sparkContext
-        .hadoopFile[LongWritable, Text, TextInputFormat](location)
+        .newAPIHadoopFile[LongWritable, Text, EncodingTextInputFormat](location)
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
     }
   }
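With the series applied, a non-ASCII-compatible file such as the UTF-16 test resource added in patch 3 can be read through the options the datasource already exposes. A hedged usage sketch (only the encoding option is shown; any other options for that particular file are assumptions, not part of this diff):

    // "charset" is an accepted alias for "encoding"; UTF-8 remains the default
    // when neither option is set.
    val cars = sqlContext.read
      .format("csv")
      .option("encoding", "UTF-16")
      .load("sql/core/src/test/resources/cars_utf-16.csv")
    cars.show()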