diff --git a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala index e4f49eb8..5a981e2f 100644 --- a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala +++ b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala @@ -15,12 +15,12 @@ */ package com.databricks.spark.xml -import java.io.{InputStream, IOException} +import java.io.{IOException, InputStream, InputStreamReader, Reader} import java.nio.charset.Charset import org.apache.hadoop.fs.Seekable import org.apache.hadoop.io.compress._ -import org.apache.hadoop.io.{DataOutputBuffer, LongWritable, Text} +import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{FileSplit, TextInputFormat} @@ -50,34 +50,26 @@ object XmlInputFormat { * as specified by the start tag and end tag */ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { - private var startTag: Array[Byte] = _ - private var currentStartTag: Array[Byte] = _ - private var endTag: Array[Byte] = _ - private var endEmptyTag: Array[Byte] = _ - private var whiteSpaces: Seq[Array[Byte]] = _ - private var angleBracket: Array[Byte] = _ + private var startTag: String = _ + private var currentStartTag: String = _ + private var endTag: String = _ private var currentKey: LongWritable = _ private var currentValue: Text = _ - private var start: Long = _ private var end: Long = _ - private var in: InputStream = _ + private var reader: Reader = _ private var filePosition: Seekable = _ private var decompressor: Decompressor = _ - - private val buffer: DataOutputBuffer = new DataOutputBuffer + private var buffer = new StringBuilder() override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = { val fileSplit = split.asInstanceOf[FileSplit] val conf = context.getConfiguration val charset = Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET)) - startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset) - endTag = conf.get(XmlInputFormat.END_TAG_KEY).getBytes(charset) - endEmptyTag = "/>".getBytes(charset) - whiteSpaces = Seq(" ", "\r", "\n").map(_.getBytes(charset)) - angleBracket = ">".getBytes(charset) + startTag = conf.get(XmlInputFormat.START_TAG_KEY) + endTag = conf.get(XmlInputFormat.END_TAG_KEY) start = fileSplit.getStart end = start + fileSplit.getLength @@ -86,6 +78,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { val fs = path.getFileSystem(conf) val fsin = fs.open(fileSplit.getPath) + var in: InputStream = null val codec = new CompressionCodecFactory(conf).getCodec(path) if (null != codec) { decompressor = CodecPool.getDecompressor(codec) @@ -117,6 +110,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { filePosition = fsin filePosition.seek(start) } + reader = new InputStreamReader(in, charset) } override def nextKeyValue: Boolean = { @@ -136,77 +130,90 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { private def next(key: LongWritable, value: Text): Boolean = { if (readUntilStartElement()) { try { - buffer.write(currentStartTag) - if (readUntilEndElement()) { + buffer.append(currentStartTag) + if (readUntilEndElement(currentStartTag.endsWith(">"))) { key.set(filePosition.getPos) - value.set(buffer.getData, 0, buffer.getLength) - true - } else { - false + value.set(buffer.toString()) + return true } } finally { - buffer.reset + buffer = new StringBuilder() } - } else { - false } + false } private def readUntilStartElement(): Boolean = { currentStartTag = startTag var i = 0 while (true) { - val b = in.read() - if (b == -1 || (i == 0 && filePosition.getPos > end)) { + val cOrEOF = reader.read() + if (cOrEOF == -1 || (i == 0 && filePosition.getPos > end)) { // End of file or end of split. return false + } + val c = cOrEOF.toChar + if (c == startTag(i)) { + if (i >= startTag.length - 1) { + // Found start tag. + return true + } + // else in start tag + i += 1 } else { - if (b.toByte == startTag(i)) { - if (i >= startTag.length - 1) { - // Found start tag. - return true - } else { - // In start tag. - i += 1 - } - } else { - if (i == (startTag.length - angleBracket.length) && checkAttributes(b)) { - // Found start tag with attributes. - return true - } else { - // Not in start tag. - i = 0 - } + // if doesn't match the closing angle bracket, check if followed by attributes + if (i == (startTag.length - 1) && Character.isWhitespace(c)) { + // Found start tag with attributes. Remember to write with following whitespace + // char, not angle bracket + currentStartTag = startTag.dropRight(1) + c + return true } + // else not in start tag + i = 0 } } // Unreachable. false } - private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = { - if (position >= endEmptyTag.length) false - else currentLetter == endEmptyTag(position) - } - - private def readUntilEndElement(): Boolean = { + private def readUntilEndElement(startTagClosed: Boolean): Boolean = { + // Index into the start or end tag that has matched so far var si = 0 var ei = 0 + // How many other start tags enclose the one that's started already? var depth = 0 + // Previously read character + var prevC = '\u0000' + + // The current start tag already found may or may not have terminated with + // a '>' as it may have attributes we read here. If not, we search for + // a self-close tag, but only until a non-self-closing end to the start + // tag is found + var canSelfClose = !startTagClosed while (true) { - val rb = in.read() - if (rb == -1) { + + val cOrEOF = reader.read() + if (cOrEOF == -1) { // End of file (ignore end of split). return false - } else { - buffer.write(rb) - val b = rb.toByte - if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) { + } + + val c = cOrEOF.toChar + buffer.append(c) + + if (c == '>' && prevC != '/') { + canSelfClose = false + } + + // Still matching a start tag? + if (c == startTag(si)) { + // Still also matching an end tag? + if (c == endTag(ei)) { // In start tag or end tag. si += 1 ei += 1 - } else if (b == startTag(si)) { + } else { if (si >= startTag.length - 1) { // Found start tag. si = 0 @@ -217,57 +224,39 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { si += 1 ei = 0 } - } else if (b == endTag(ei) || checkEmptyTag(b, ei)) { - if ((b == endTag(ei) && ei >= endTag.length - 1) || - (checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) { - if (depth == 0) { - // Found closing end tag. - return true - } else { - // Found nested end tag. - si = 0 - ei = 0 - depth -= 1 - } - } else { - // In end tag. - si = 0 - ei += 1 + } + } else if (c == endTag(ei)) { + if (ei >= endTag.length - 1) { + if (depth == 0) { + // Found closing end tag. + return true } - } else { - // Not in start tag or end tag. + // else found nested end tag. si = 0 ei = 0 + depth -= 1 + } else { + // In end tag. + si = 0 + ei += 1 } - } - } - // Unreachable. - false - } - - private def checkAttributes(current: Int): Boolean = { - val matchedSpace = Array.fill(whiteSpaces.length)(true) - val maxLen = whiteSpaces.map(_.length).max - var b = current - // Loop over input until looking for whitespace until max length of space is reached - for (i <- 0 until maxLen) { - // Loop over all whitespace bytes - for (j <- whiteSpaces.indices) { - val len = whiteSpaces(j).length - // If match so far but current byte doesn't match, rule out char - if (matchedSpace(j) && i < len && (b != whiteSpaces(j)(i))) { - matchedSpace(j) = false - } - // If checked last char and still matches, this is whitespace - if (i == len - 1 && matchedSpace(j)) { - // take tag plus the whitespace bytes - currentStartTag = - startTag.take(startTag.length - angleBracket.length) ++ whiteSpaces(j) + } else if (c == '>' && prevC == '/' && canSelfClose) { + if (depth == 0) { + // found a self-closing tag (end tag) return true } + // else found self-closing nested tag (end tag) + si = 0 + ei = 0 + depth -= 1 + } else { + // Not in start tag or end tag. + si = 0 + ei = 0 } - b = in.read + prevC = c } + // Unreachable. false } @@ -279,8 +268,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] { def close(): Unit = { try { - if (in != null) { - in.close() + if (reader != null) { + reader.close() + reader = null } } finally { if (decompressor != null) { diff --git a/src/test/resources/self-closing-tag.xml b/src/test/resources/self-closing-tag.xml new file mode 100644 index 00000000..c3057b22 --- /dev/null +++ b/src/test/resources/self-closing-tag.xml @@ -0,0 +1,6 @@ + + + 1 + + + diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala index 03f80aa9..a4fcbee3 100755 --- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala @@ -63,6 +63,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { val attributesStartWithNewLine = "src/test/resources/attributesStartWithNewLine.xml" val attributesStartWithNewLineLF = "src/test/resources/attributesStartWithNewLineLF.xml" val attributesStartWithNewLineCR = "src/test/resources/attributesStartWithNewLineCR.xml" + val selfClosingTag = "src/test/resources/self-closing-tag.xml" val booksTag = "book" val booksRootTag = "books" @@ -941,4 +942,19 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll { assert(df.count() == rowsCount) } } + + test("Produces correct result for a row with a self closing tag inside") { + val schema = StructType(Seq( + StructField("non-empty-tag", IntegerType, nullable = true), + StructField("self-closing-tag", IntegerType, nullable = true) + )) + + val result = new XmlReader() + .withSchema(schema) + .xmlFile(spark, selfClosingTag) + .collect() + + assert(result(0).toSeq === Seq(1, null)) + } + }