Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.databricks.spark.xml

import java.io.{InputStream, IOException}
import java.io.{IOException, InputStream}
import java.nio.charset.Charset

import org.apache.hadoop.conf.Configuration
Expand All @@ -25,14 +25,17 @@ import org.apache.hadoop.io.{DataOutputBuffer, LongWritable, Text}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.{FileSplit, TextInputFormat}

import scala.collection.mutable.ArrayBuffer

/**
* Reads records that are delimited by a specific start/end tag.
*/
class XmlInputFormat extends TextInputFormat {

override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[LongWritable, Text] = {
split: InputSplit,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise this needs to be reverted

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't revert this change, because the scalacheckstyle goal will fail on it.

image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line isn't too long if you don't add the deep continuation indent. The code passed style checks before.

context: TaskAttemptContext):
RecordReader[LongWritable, Text] = {
new XmlRecordReader
}
}
Expand All @@ -47,9 +50,9 @@ object XmlInputFormat {
}

/**
* XMLRecordReader class to read through a given xml document to output xml blocks as records
* as specified by the start tag and end tag
*/
* XMLRecordReader class to read through a given xml document to output xml blocks as records
* as specified by the start tag and end tag
*/
private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
private var startTag: Array[Byte] = _
private var currentStartTag: Array[Byte] = _
Expand Down Expand Up @@ -111,7 +114,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
// So we have a split that is only part of a file stored using
// a Compression codec that cannot be split.
throw new IOException("Cannot seek in " +
codec.getClass.getSimpleName + " compressed stream")
codec.getClass.getSimpleName + " compressed stream")
}
val cIn = c.createInputStream(fsin, decompressor)
in = cIn
Expand All @@ -131,13 +134,13 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
}

/**
* Finds the start of the next record.
* It treats data from `startTag` and `endTag` as a record.
*
* @param key the current key that will be written
* @param value the object that will be written
* @return whether it reads successfully
*/
* Finds the start of the next record.
* It treats data from `startTag` and `endTag` as a record.
*
* @param key the current key that will be written
* @param value the object that will be written
* @return whether it reads successfully
*/
private def next(key: LongWritable, value: Text): Boolean = {
if (readUntilStartElement()) {
try {
Expand Down Expand Up @@ -189,9 +192,24 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
false
}

private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = {
private def checkEmptyTag(currentLetter: Int, position: Int,
buffer: DataOutputBuffer): Boolean = {
def checkStartTagBefore = {
val startAngleInByte = '<'.toByte
val spaceInByte = ' '.toByte
val rootTagName = buffer.getData
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why define a method here?
I don't think this is efficient enough as it makes a few copies of much of the buffer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it's used only here.

I don't think this is efficient enough as it makes a few copies of much of the buffer.

Maybe, but this code runs only in rare situations.
Code that works directly on the buffer is not so easy to read.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me think about it and maybe try a different approach. I think we just need to look for a self-close tag that can only come before any other tag close.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(BTW here I meant, why define a method at all? it's only used here, as you say)

.reverse
.takeWhile(_ != startAngleInByte)
.reverse
.takeWhile(_ != spaceInByte)
val result = startAngleInByte +: rootTagName

result.sameElements(startTag.dropRight(1))
}

if (position >= endEmptyTag.length) false
else currentLetter == endEmptyTag(position)
else currentLetter == endEmptyTag(position) &&
checkStartTagBefore
}

private def readUntilEndElement(): Boolean = {
Expand All @@ -207,7 +225,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
} else {
buffer.write(rb)
val b = rb.toByte
if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) {
if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei, buffer))) {
// In start tag or end tag.
si += 1
ei += 1
Expand All @@ -222,9 +240,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
si += 1
ei = 0
}
} else if (b == endTag(ei) || checkEmptyTag(b, ei)) {
} else if (b == endTag(ei) || checkEmptyTag(b, ei, buffer)) {
if ((b == endTag(ei) && ei >= endTag.length - 1) ||
(checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) {
(checkEmptyTag(b, ei, buffer) && ei >= endEmptyTag.length - 1)) {
if (depth == 0) {
// Found closing end tag.
return true
Expand Down Expand Up @@ -253,7 +271,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
private def checkAttributes(current: Int): Boolean = {
var len = 0
var b = current
while(len < space.length && b == space(len)) {
while (len < space.length && b == space(len)) {
len += 1
if (len >= space.length) {
currentStartTag = startTag.take(startTag.length - angleBracket.length) ++ space
Expand Down
6 changes: 6 additions & 0 deletions src/test/resources/self-closing-tag.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<ROWSET>
<ROW>
<non-empty-tag>1</non-empty-tag>
<self-closing-tag/>
</ROW>
</ROWSET>
18 changes: 17 additions & 1 deletion src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
val nestedElementWithNameOfParent = "src/test/resources/nested-element-with-name-of-parent.xml"
val booksMalformedAttributes = "src/test/resources/books-malformed-attributes.xml"
val fiasHouse = "src/test/resources/fias_house.xml"
val fiasHouseSimple = "src/test/resources/fias_house_simple.xml"
val selfClosingTag = "src/test/resources/self-closing-tag.xml"

val booksTag = "book"
val booksRootTag = "books"
Expand Down Expand Up @@ -904,7 +906,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
assert(results(1)(0) === "bk112")
}

test("read utf-8 encoded file with empty tag") {
test("empty tag data only in attributes") {
val df = spark.read.format("xml")
.option("excludeAttribute", "false")
.option("rowTag", fiasRowTag)
Expand All @@ -913,4 +915,18 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
assert(df.collect().length == numFiasHouses)
assert(df.select().where("_HOUSEID is null").count() == 0)
}

test("Produces correct result for a row with a self closing tag inside") {
  // Both columns are declared as nullable integers so that the self-closing
  // element, which carries no value, can be represented as null.
  val columns = Seq("non-empty-tag", "self-closing-tag")
    .map(name => StructField(name, IntegerType, nullable = true))

  val rows = new XmlReader()
    .withSchema(StructType(columns))
    .xmlFile(spark, selfClosingTag)
    .collect()

  // The first parsed row must keep the value of the regular element and
  // yield null for the self-closing one.
  val firstRow = rows(0).toSeq
  assert(firstRow === Seq(1, null))
}
}