Merged
2 changes: 1 addition & 1 deletion build.sbt
@@ -1,6 +1,6 @@
name := "spark-xml"

version := "0.4.1"
version := "0.4.2"

organization := "com.databricks"

15 changes: 12 additions & 3 deletions src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
@@ -54,6 +54,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
   private var startTag: Array[Byte] = _
   private var currentStartTag: Array[Byte] = _
   private var endTag: Array[Byte] = _
+  private var endEmptyTag: Array[Byte] = _
   private var space: Array[Byte] = _
   private var angleBracket: Array[Byte] = _

@@ -75,6 +76,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
       Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET))
     startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset)
     endTag = conf.get(XmlInputFormat.END_TAG_KEY).getBytes(charset)
+    endEmptyTag = "/>".getBytes(charset)
     space = " ".getBytes(charset)
     angleBracket = ">".getBytes(charset)
     require(startTag != null, "Start tag cannot be null.")
@@ -187,10 +189,16 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
     false
   }
 
+  private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = {
+    if (position >= endEmptyTag.length) false
+    else currentLetter == endEmptyTag(position)
+  }
+
   private def readUntilEndElement(): Boolean = {
     var si = 0
     var ei = 0
     var depth = 0
+
     while (true) {
       val rb = in.read()
       if (rb == -1) {
@@ -199,7 +207,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
       } else {
         buffer.write(rb)
         val b = rb.toByte
-        if (b == startTag(si) && b == endTag(ei)) {
+        if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) {
           // In start tag or end tag.
           si += 1
           ei += 1
@@ -214,8 +222,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
           si += 1
           ei = 0
         }
-      } else if (b == endTag(ei)) {
-        if (ei >= endTag.length - 1) {
+      } else if (b == endTag(ei) || checkEmptyTag(b, ei)) {
+        if ((b == endTag(ei) && ei >= endTag.length - 1) ||
+            (checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) {
           if (depth == 0) {
             // Found closing end tag.
             return true
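Taken together, these hunks teach the record reader to treat a self-closing row element such as <row attr="1"/> as a complete record: while scanning bytes it now tracks a match against "/>" alongside the match against the configured end tag. A minimal, self-contained sketch of that matching idea follows; `EmptyTagScan`, `findRecordEnd`, the hard-coded "</row>", and the sample inputs are illustrative assumptions, not part of the patch.

```scala
// A minimal sketch of the parallel end-tag / empty-tag matching the patch
// introduces. `endTag` stands in for "</row>", `endEmptyTag` for "/>".
object EmptyTagScan {
  val endTag: Array[Byte] = "</row>".getBytes("UTF-8")
  val endEmptyTag: Array[Byte] = "/>".getBytes("UTF-8")

  // Mirrors checkEmptyTag: does byte `b` extend a "/>"-match at `position`?
  def checkEmptyTag(b: Byte, position: Int): Boolean =
    position < endEmptyTag.length && b == endEmptyTag(position)

  // Returns the offset just past the first record terminator, or -1 if none.
  def findRecordEnd(bytes: Array[Byte]): Int = {
    var ei = 0 // shared match index, as in the patched reader
    var i = 0
    while (i < bytes.length) {
      val b = bytes(i)
      if (b == endTag(ei) || checkEmptyTag(b, ei)) {
        if ((b == endTag(ei) && ei == endTag.length - 1) ||
            (checkEmptyTag(b, ei) && ei == endEmptyTag.length - 1)) {
          return i + 1 // full "</row>" or "/>" matched
        }
        ei += 1
      } else {
        ei = 0
      }
      i += 1
    }
    -1
  }

  def main(args: Array[String]): Unit = {
    println(findRecordEnd("<row a=\"1\"></row>".getBytes("UTF-8"))) // 17
    println(findRecordEnd("<row a=\"1\"/>".getBytes("UTF-8")))      // 12
  }
}
```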
3 changes: 2 additions & 1 deletion src/main/scala/com/databricks/spark/xml/XmlOptions.scala
@@ -57,12 +57,13 @@ private[xml] class XmlOptions(
     logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
   }
 
+  require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
+
   val failFast = ParseModes.isFailFastMode(parseMode)
   val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
   val permissive = ParseModes.isPermissiveMode(parseMode)
 
   require(rowTag.nonEmpty, "'rowTag' option should not be empty string.")
-  require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
   require(valueTag.nonEmpty, "'valueTag' option should not be empty string.")
   require(valueTag != attributePrefix,
     "'valueTag' and 'attributePrefix' options should not be the same.")
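This moves the `attributePrefix` validation earlier in the class body. Statements in a Scala class body execute in declaration order at construction, so the relocated `require` now fires before the declarations that follow it; the first failing check decides which error the caller sees. A tiny sketch of that ordering rule, using a made-up `Opts` class rather than the real `XmlOptions`:

```scala
// Statements in a Scala class body run in declaration order at construction,
// so the first failing `require` determines the reported error.
class Opts(attributePrefix: String, valueTag: String) {
  require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
  require(valueTag.nonEmpty, "'valueTag' option should not be empty string.")
  require(valueTag != attributePrefix,
    "'valueTag' and 'attributePrefix' options should not be the same.")
}

object OptsDemo extends App {
  try new Opts(attributePrefix = "", valueTag = "_VALUE")
  catch {
    // Prints: requirement failed: 'attributePrefix' option should not be empty string.
    case e: IllegalArgumentException => println(e.getMessage)
  }
}
```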
StaxXmlParser.scala
@@ -234,16 +234,15 @@ private[xml] object StaxXmlParser {
       options: XmlOptions,
       rootAttributes: Array[Attribute] = Array.empty): Row = {
     val row = new Array[Any](schema.length)
+    val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+    // If there are attributes, then we process them first.
+    convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
+      nameToIndex.get(f).foreach { row(_) = v }
+    }
     var shouldStop = false
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
-          val nameToIndex = schema.map(_.name).zipWithIndex.toMap
-          // If there are attributes, then we process them first.
-          convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
-            nameToIndex.get(f).foreach { row(_) = v }
-          }
-
           val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
           val field = e.asStartElement.getName.getLocalPart

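Here the `nameToIndex` map and the root-attribute conversion are hoisted out of the `StartElement` case so they run once per record, before any events are consumed; a row that yields no further `StartElement` events (for example, a self-closing row) now still receives its attribute values. A schematic sketch of the pattern, with made-up stand-ins (`Event`, `Start`, `fillRow`) instead of the parser's real types:

```scala
// Schematic: per-record work that does not depend on the current event is
// hoisted above the loop, so it runs exactly once per record, even when the
// record (e.g. <row id="1"/>) yields no StartElement events for children.
sealed trait Event
final case class Start(name: String) extends Event
case object End extends Event

def fillRow(schema: Seq[String],
            rootAttrs: Map[String, Any],
            events: Iterator[Event]): Array[Any] = {
  val row = new Array[Any](schema.length)
  val nameToIndex = schema.zipWithIndex.toMap // hoisted: once per record
  rootAttrs.foreach { case (f, v) =>          // hoisted: once per record
    nameToIndex.get(f).foreach { row(_) = v }
  }
  events.foreach {
    case Start(_) => () // child fields would be converted here
    case End      => ()
  }
  row
}
```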
StaxXmlParserUtils.scala
@@ -23,6 +23,7 @@ private[xml] object StaxXmlParserUtils {
   def checkEndElement(parser: XMLEventReader): Boolean = {
     parser.peek match {
       case _: EndElement => true
+      case _: EndDocument => true
       case _: StartElement => false
       case _ =>
         // When other events are found here rather than `EndElement` or `StartElement`
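`checkEndElement` peeks at the next event to decide whether the current element is finished; at the end of the document a StAX reader yields `EndDocument`, which this new case now treats as "ended" as well. A small self-contained demonstration of what `peek` returns after the last element (the `<row/>` sample input is illustrative):

```scala
import java.io.StringReader
import javax.xml.stream.XMLInputFactory
import javax.xml.stream.events.{EndDocument, EndElement, StartElement}

// After the last element of a document, peek() yields EndDocument, so a
// check that only recognizes EndElement would never report the end here.
object PeekDemo extends App {
  val reader = XMLInputFactory.newInstance()
    .createXMLEventReader(new StringReader("<row/>"))
  reader.nextEvent() // StartDocument
  reader.nextEvent() // StartElement <row>
  reader.nextEvent() // EndElement (synthesized for the self-closing tag)
  reader.peek() match {
    case _: EndElement | _: EndDocument => println("end of record")
    case _: StartElement                => println("more content")
    case other                          => println(s"unexpected: $other")
  }
}
```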
17 changes: 8 additions & 9 deletions src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
@@ -170,18 +170,17 @@ private[xml] object InferSchema {
       rootAttributes: Array[Attribute] = Array.empty): DataType = {
     val builder = Seq.newBuilder[StructField]
     val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
+    // If there are attributes, then we should process them first.
+    val rootValuesMap =
+      StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
+    rootValuesMap.foreach {
+      case (f, v) =>
+        nameToDataType += (f -> ArrayBuffer(inferFrom(v, options)))
+    }
     var shouldStop = false
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
-          // If there are attributes, then we should process them first.
-          val rootValuesMap =
-            StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
-          rootValuesMap.foreach {
-            case (f, v) =>
-              nameToDataType += (f -> ArrayBuffer(inferFrom(v, options)))
-          }
-
           val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
           val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
           val inferredType = inferField(parser, options) match {
@@ -222,7 +221,7 @@ private[xml] object InferSchema {
     }
     // We need to manually merges the fields having the sames so that
     // This can be inferred as ArrayType.
-    nameToDataType.foreach{
+    nameToDataType.foreach {
       case (field, dataTypes) if dataTypes.length > 1 =>
         val elementType = dataTypes.reduceLeft(InferSchema.compatibleType(options))
         builder += StructField(field, ArrayType(elementType), nullable = true)
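Schema inference receives the same hoisting as row conversion above: root-attribute types are collected once before the event loop, so attribute-only rows contribute fields to the inferred schema. A hypothetical end-to-end check of what the PR as a whole enables; the file path, sample data, and expected output are illustrative, not taken from the patch:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical check: rows that are self-closing elements with attributes
// only, e.g. <row id="1"/>, should now parse instead of being dropped.
object SelfClosingRowsDemo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("spark-xml self-closing rows")
    .getOrCreate()

  // Assume /tmp/rows.xml contains: <rows><row id="1"/><row id="2"/></rows>
  val df = spark.read
    .format("com.databricks.spark.xml")
    .option("rowTag", "row")
    .load("/tmp/rows.xml")

  df.printSchema() // expected: a single attribute column, "_id"
  df.show()        // expected: rows 1 and 2, not an empty result
  spark.stop()
}
```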