Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 92 additions & 102 deletions src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
*/
package com.databricks.spark.xml

import java.io.{InputStream, IOException}
import java.io.{IOException, InputStream, InputStreamReader, Reader}
import java.nio.charset.Charset

import org.apache.hadoop.fs.Seekable
import org.apache.hadoop.io.compress._
import org.apache.hadoop.io.{DataOutputBuffer, LongWritable, Text}
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{FileSplit, TextInputFormat}

Expand Down Expand Up @@ -50,34 +50,26 @@ object XmlInputFormat {
* as specified by the start tag and end tag
*/
private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
private var startTag: Array[Byte] = _
private var currentStartTag: Array[Byte] = _
private var endTag: Array[Byte] = _
private var endEmptyTag: Array[Byte] = _
private var whiteSpaces: Seq[Array[Byte]] = _
private var angleBracket: Array[Byte] = _

private var startTag: String = _
private var currentStartTag: String = _
private var endTag: String = _
private var currentKey: LongWritable = _
private var currentValue: Text = _

private var start: Long = _
private var end: Long = _
private var in: InputStream = _
private var reader: Reader = _
private var filePosition: Seekable = _
private var decompressor: Decompressor = _

private val buffer: DataOutputBuffer = new DataOutputBuffer
private var buffer = new StringBuilder()

override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {
val fileSplit = split.asInstanceOf[FileSplit]
val conf = context.getConfiguration
val charset =
Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET))
startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset)
endTag = conf.get(XmlInputFormat.END_TAG_KEY).getBytes(charset)
endEmptyTag = "/>".getBytes(charset)
whiteSpaces = Seq(" ", "\r", "\n").map(_.getBytes(charset))
angleBracket = ">".getBytes(charset)
startTag = conf.get(XmlInputFormat.START_TAG_KEY)
endTag = conf.get(XmlInputFormat.END_TAG_KEY)
start = fileSplit.getStart
end = start + fileSplit.getLength

Expand All @@ -86,6 +78,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
val fs = path.getFileSystem(conf)
val fsin = fs.open(fileSplit.getPath)

var in: InputStream = null
val codec = new CompressionCodecFactory(conf).getCodec(path)
if (null != codec) {
decompressor = CodecPool.getDecompressor(codec)
Expand Down Expand Up @@ -117,6 +110,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
filePosition = fsin
filePosition.seek(start)
}
reader = new InputStreamReader(in, charset)
}

override def nextKeyValue: Boolean = {
Expand All @@ -136,77 +130,90 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
private def next(key: LongWritable, value: Text): Boolean = {
if (readUntilStartElement()) {
try {
buffer.write(currentStartTag)
if (readUntilEndElement()) {
buffer.append(currentStartTag)
if (readUntilEndElement(currentStartTag.endsWith(">"))) {
key.set(filePosition.getPos)
value.set(buffer.getData, 0, buffer.getLength)
true
} else {
false
value.set(buffer.toString())
return true
}
} finally {
buffer.reset
buffer = new StringBuilder()
}
} else {
false
}
false
}

private def readUntilStartElement(): Boolean = {
currentStartTag = startTag
var i = 0
while (true) {
val b = in.read()
if (b == -1 || (i == 0 && filePosition.getPos > end)) {
val cOrEOF = reader.read()
if (cOrEOF == -1 || (i == 0 && filePosition.getPos > end)) {
// End of file or end of split.
return false
}
val c = cOrEOF.toChar
if (c == startTag(i)) {
if (i >= startTag.length - 1) {
// Found start tag.
return true
}
// else in start tag
i += 1
} else {
if (b.toByte == startTag(i)) {
if (i >= startTag.length - 1) {
// Found start tag.
return true
} else {
// In start tag.
i += 1
}
} else {
if (i == (startTag.length - angleBracket.length) && checkAttributes(b)) {
// Found start tag with attributes.
return true
} else {
// Not in start tag.
i = 0
}
// if doesn't match the closing angle bracket, check if followed by attributes
if (i == (startTag.length - 1) && Character.isWhitespace(c)) {
// Found start tag with attributes. Remember to write with following whitespace
// char, not angle bracket
currentStartTag = startTag.dropRight(1) + c
return true
}
// else not in start tag
i = 0
}
}
// Unreachable.
false
}

private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = {
if (position >= endEmptyTag.length) false
else currentLetter == endEmptyTag(position)
}

private def readUntilEndElement(): Boolean = {
private def readUntilEndElement(startTagClosed: Boolean): Boolean = {
// Index into the start or end tag that has matched so far
var si = 0
var ei = 0
// How many other start tags enclose the one that's started already?
var depth = 0
// Previously read character
var prevC = '\u0000'

// The current start tag already found may or may not have terminated with
// a '>' as it may have attributes we read here. If not, we search for
// a self-close tag, but only until a non-self-closing end to the start
// tag is found
var canSelfClose = !startTagClosed

while (true) {
val rb = in.read()
if (rb == -1) {

val cOrEOF = reader.read()
if (cOrEOF == -1) {
// End of file (ignore end of split).
return false
} else {
buffer.write(rb)
val b = rb.toByte
if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) {
}

val c = cOrEOF.toChar
buffer.append(c)

if (c == '>' && prevC != '/') {
canSelfClose = false
}

// Still matching a start tag?
if (c == startTag(si)) {
// Still also matching an end tag?
if (c == endTag(ei)) {
// In start tag or end tag.
si += 1
ei += 1
} else if (b == startTag(si)) {
} else {
if (si >= startTag.length - 1) {
// Found start tag.
si = 0
Expand All @@ -217,57 +224,39 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
si += 1
ei = 0
}
} else if (b == endTag(ei) || checkEmptyTag(b, ei)) {
if ((b == endTag(ei) && ei >= endTag.length - 1) ||
(checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) {
if (depth == 0) {
// Found closing end tag.
return true
} else {
// Found nested end tag.
si = 0
ei = 0
depth -= 1
}
} else {
// In end tag.
si = 0
ei += 1
}
} else if (c == endTag(ei)) {
if (ei >= endTag.length - 1) {
if (depth == 0) {
// Found closing end tag.
return true
}
} else {
// Not in start tag or end tag.
// else found nested end tag.
si = 0
ei = 0
depth -= 1
} else {
// In end tag.
si = 0
ei += 1
}
}
}
// Unreachable.
false
}

private def checkAttributes(current: Int): Boolean = {
val matchedSpace = Array.fill(whiteSpaces.length)(true)
val maxLen = whiteSpaces.map(_.length).max
var b = current
// Loop over input until looking for whitespace until max length of space is reached
for (i <- 0 until maxLen) {
// Loop over all whitespace bytes
for (j <- whiteSpaces.indices) {
val len = whiteSpaces(j).length
// If match so far but current byte doesn't match, rule out char
if (matchedSpace(j) && i < len && (b != whiteSpaces(j)(i))) {
matchedSpace(j) = false
}
// If checked last char and still matches, this is whitespace
if (i == len - 1 && matchedSpace(j)) {
// take tag plus the whitespace bytes
currentStartTag =
startTag.take(startTag.length - angleBracket.length) ++ whiteSpaces(j)
} else if (c == '>' && prevC == '/' && canSelfClose) {
if (depth == 0) {
// found a self-closing tag (end tag)
return true
}
// else found self-closing nested tag (end tag)
si = 0
ei = 0
depth -= 1
} else {
// Not in start tag or end tag.
si = 0
ei = 0
}
b = in.read
prevC = c
}
// Unreachable.
false
}

Expand All @@ -279,8 +268,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {

def close(): Unit = {
try {
if (in != null) {
in.close()
if (reader != null) {
reader.close()
reader = null
}
} finally {
if (decompressor != null) {
Expand Down
6 changes: 6 additions & 0 deletions src/test/resources/self-closing-tag.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<ROWSET>
<ROW>
<non-empty-tag>1</non-empty-tag>
<self-closing-tag/>
</ROW>
</ROWSET>
16 changes: 16 additions & 0 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val attributesStartWithNewLine = "src/test/resources/attributesStartWithNewLine.xml"
val attributesStartWithNewLineLF = "src/test/resources/attributesStartWithNewLineLF.xml"
val attributesStartWithNewLineCR = "src/test/resources/attributesStartWithNewLineCR.xml"
val selfClosingTag = "src/test/resources/self-closing-tag.xml"

val booksTag = "book"
val booksRootTag = "books"
Expand Down Expand Up @@ -941,4 +942,19 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
assert(df.count() == rowsCount)
}
}

test("Produces correct result for a row with a self closing tag inside") {
  // Both columns are declared as nullable integers so that the self-closing
  // element can be checked against a null value.
  val nonEmptyField = StructField("non-empty-tag", IntegerType, nullable = true)
  val selfClosingField = StructField("self-closing-tag", IntegerType, nullable = true)
  val rowSchema = StructType(Seq(nonEmptyField, selfClosingField))

  // Read the fixture with the explicit schema and materialise every row.
  val xmlReader = new XmlReader().withSchema(rowSchema)
  val rows = xmlReader.xmlFile(spark, selfClosingTag).collect()

  // The first row must hold 1 for the populated tag and null for the self-closing one.
  val firstRowValues = rows(0).toSeq
  assert(firstRowValues === Seq(1, null))
}

}