
Commit ff339a5

committed
add EscapedTextInputFormat
1 parent fb0db77 commit ff339a5

File tree

2 files changed: 347 additions & 0 deletions
EscapedTextInputFormat.scala
Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.input

import java.io.{BufferedReader, IOException, InputStreamReader}

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataInputStream
import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}

/**
 * Input format for text records saved with in-record delimiter, newline, and escape characters
 * escaped.
 *
 * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|`
 * should be the following:
 * {{{
 * a\\\n|\\|b\\\\\n
 * }}}
 * where the in-record `|`, `\n`, and `\\` characters are escaped by `\\`.
 * Users can configure the delimiter via [[EscapedTextInputFormat$#KEY_DELIMITER]].
 * Its default value [[EscapedTextInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's
 * UNLOAD with the ESCAPE option:
 * {{{
 * UNLOAD ('select_statement')
 * TO 's3://object_path_prefix'
 * ESCAPE
 * }}}
 *
 * @see org.apache.spark.SparkContext#newAPIHadoopFile
 */
class EscapedTextInputFormat extends FileInputFormat[Long, Array[String]] {

  override def createRecordReader(
      split: InputSplit,
      context: TaskAttemptContext): RecordReader[Long, Array[String]] = {
    new EscapedTextRecordReader
  }
}
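// A minimal usage sketch (illustrative; not part of this commit). It mirrors the
// newAPIHadoopFile call in the test suite below; `sc` and the input path are assumptions.
//
//   val rdd = sc.newAPIHadoopFile("/path/to/escaped-text",
//     classOf[EscapedTextInputFormat], classOf[Long], classOf[Array[String]])
//   val records = rdd.values.map(_.toSeq) // each value is one record's unescaped fields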

object EscapedTextInputFormat {

  /** configuration key for delimiter */
  val KEY_DELIMITER = "spark.input.escapedText.delimiter"

  /** default delimiter */
  val DEFAULT_DELIMITER = '|'

  /** Gets the delimiter char from conf or the default. */
  private[input] def getDelimiterOrDefault(conf: Configuration): Char = {
    val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString)
    if (c.length != 1) {
      throw new IllegalArgumentException(
        s"Expected the delimiter to be a single character but got '$c'.")
    } else {
      c.charAt(0)
    }
  }
}
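// Overriding the delimiter through a Hadoop Configuration (a sketch assuming
// tab-separated input; it mirrors the "customized delimiter" test below):
//
//   val conf = new Configuration()
//   conf.set(EscapedTextInputFormat.KEY_DELIMITER, "\t")
//   val rdd = sc.newAPIHadoopFile(path, classOf[EscapedTextInputFormat],
//     classOf[Long], classOf[Array[String]], conf)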

private[input] class EscapedTextRecordReader extends RecordReader[Long, Array[String]] {

  private var reader: BufferedReader = _

  private var key: Long = _
  private var value: Array[String] = _

  private var start: Long = _
  private var end: Long = _
  private var cur: Long = _

  private var delimiter: Char = _
  @inline private[this] final val escapeChar = '\\'
  @inline private[this] final val newline = '\n'

  @inline private[this] final val defaultBufferSize = 64 * 1024

  override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
    val split = inputSplit.asInstanceOf[FileSplit]
    val file = split.getPath
    val conf = context.getConfiguration
    delimiter = EscapedTextInputFormat.getDelimiterOrDefault(conf)
    require(delimiter != escapeChar,
      s"The delimiter and the escape character cannot be the same, but found '$delimiter'.")
    require(delimiter != newline, "The delimiter cannot be the newline character.")
    val compressionCodecs = new CompressionCodecFactory(conf)
    val codec = compressionCodecs.getCodec(file)
    if (codec != null) {
      throw new IOException(s"Compressed files are not supported, but found $file.")
    }
    val fs = file.getFileSystem(conf)
    val in = fs.open(file)
    // This reader handles the records that start within [split start, split end).
    start = findNext(in, split.getStart)
    end = findNext(in, split.getStart + split.getLength)
    cur = start
    in.seek(cur)
    reader = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
  }
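  // Note on split stitching: a reader emits exactly the records whose first byte lies in
  // [findNext(splitStart), findNext(splitEnd)). Adjacent splits share an offset, and findNext
  // is deterministic for a given offset, so both sides compute the same boundary and every
  // record is read by exactly one task.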

  override def getProgress: Float = {
    if (start >= end) {
      1.0f
    } else {
      math.min((cur - start).toFloat / (end - start), 1.0f)
    }
  }

  override def nextKeyValue(): Boolean = {
    if (cur < end) {
      key = cur
      value = nextValue()
      true
    } else {
      false
    }
  }

  override def getCurrentValue: Array[String] = value

  override def getCurrentKey: Long = key

  override def close(): Unit = {
    if (reader != null) {
      reader.close()
    }
  }

  /**
   * Finds the start of the next record at or after `start`.
   * Because we don't know whether the char at `start` is escaped or not, we first skip past any
   * run of escape chars to reach a position that is definitely not escaped, and then scan for
   * the first unescaped newline, which terminates the previous record.
   * @return the start position of the next record
   */
  private def findNext(in: FSDataInputStream, start: Long): Long = {
    if (start == 0L) return 0L
    var pos = start
    in.seek(pos)
    val br = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
    // Phase 1: consume consecutive escape chars until a non-escape char is seen.
    var escaped = true
    var eof = false
    while (escaped && !eof) {
      val v = br.read()
      if (v < 0) {
        eof = true
      } else {
        pos += 1
        if (v != escapeChar) {
          escaped = false
        }
      }
    }
    // Phase 2: scan for the first unescaped newline; the next record starts right after it.
    var sawNewline = false
    while ((escaped || !sawNewline) && !eof) {
      val v = br.read()
      if (v < 0) {
        eof = true
      } else {
        pos += 1
        if (v == escapeChar) {
          escaped = true
        } else {
          if (!escaped) {
            sawNewline = v == newline
          } else {
            escaped = false
          }
        }
      }
    }
    pos
  }

  private def nextValue(): Array[String] = {
    var escaped = false
    val fields = ArrayBuffer.empty[String]
    var endOfRecord = false
    var eof = false
    while (!endOfRecord && !eof) {
      var endOfField = false
      val sb = new StringBuilder
      while (!endOfField && !endOfRecord && !eof) {
        val v = reader.read()
        if (v < 0) {
          eof = true
        } else {
          cur += 1
          if (escaped) {
            // Only the escape char itself, the delimiter, and newline may follow an escape.
            if (v != escapeChar && v != delimiter && v != newline) {
              throw new IllegalStateException(s"Found ${v.toChar} after $escapeChar.")
            }
            sb.append(v.toChar)
            escaped = false
          } else {
            if (v == escapeChar) {
              escaped = true
            } else if (v == delimiter) {
              endOfField = true
            } else if (v == newline) {
              endOfRecord = true
            } else {
              sb.append(v.toChar)
            }
          }
        }
      }
      fields.append(sb.toString())
    }
    if (escaped) {
      throw new IllegalStateException("Found a hanging escape character.")
    }
    fields.toArray
  }
}
EscapedTextInputFormatSuite.scala
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.input

import java.io.{DataOutputStream, File, FileOutputStream}

import scala.language.implicitConversions

import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.input.EscapedTextInputFormat._
import org.apache.spark.util.Utils

class EscapedTextInputFormatSuite extends FunSuite with BeforeAndAfterAll with Logging {

  import EscapedTextInputFormatSuite._

  private var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")

    // Set the block size of the local file system to test whether files are split correctly.
    sc.hadoopConfiguration.setLong("fs.local.block.size", 4)
  }

  override def afterAll() {
    sc.stop()
  }

  private def writeToFile(contents: String, file: File) = {
    val bytes = contents.getBytes
    val out = new DataOutputStream(new FileOutputStream(file))
    out.write(bytes, 0, bytes.length)
    out.close()
  }

  private def escape(records: Set[Seq[String]], delimiter: Char): String = {
    require(delimiter != '\\' && delimiter != '\n')
    records.map { r =>
      r.map { f =>
        f.replace("\\", "\\\\")
          .replace("\n", "\\\n")
          .replace(delimiter, "\\" + delimiter)
      }.mkString(delimiter)
    }.mkString("", "\n", "\n")
  }
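  // For example, escaping the record Seq("a\n", "|b\\") with delimiter '|' yields the string
  // "a\\\n|\\|b\\\\\n" (cf. the EscapedTextInputFormat Scaladoc): each in-record '|', '\n',
  // and '\\' is prefixed with '\\', and the record ends with an unescaped '\n'.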

  private final val TAB = '\t'

  private val records = Set(
    Seq("a\n", DEFAULT_DELIMITER + "b\\"),
    Seq("c", TAB + "d"),
    Seq("\ne", "\\\\f"))

  private def withTempDir(func: File => Unit): Unit = {
    val dir = Files.createTempDir()
    dir.deleteOnExit()
    logDebug(s"dir: $dir")
    func(dir)
    Utils.deleteRecursively(dir)
  }

test("default delimiter") {
84+
withTempDir { dir =>
85+
val escaped = escape(records, DEFAULT_DELIMITER)
86+
writeToFile(escaped, new File(dir, "part-00000"))
87+
88+
val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
89+
classOf[Long], classOf[Array[String]])
90+
assert(rdd.partitions.size > 3) // so there will be empty partitions
91+
92+
val actual = rdd.values.map(_.toSeq).collect().toSet
93+
assert(actual === records)
94+
}
95+
}
96+
97+
test("customized delimiter") {
98+
withTempDir { dir =>
99+
val escaped = escape(records, TAB)
100+
writeToFile(escaped, new File(dir, "part-00000"))
101+
102+
val conf = new Configuration
103+
conf.set(KEY_DELIMITER, TAB)
104+
105+
val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
106+
classOf[Long], classOf[Array[String]], conf)
107+
assert(rdd.partitions.size > 3) // so their will be empty partitions
108+
109+
val actual = rdd.values.map(_.toSeq).collect().toSet
110+
assert(actual === records)
111+
}
112+
}
113+
}

object EscapedTextInputFormatSuite {

  implicit def charToString(c: Char): String = c.toString
}
