diff --git a/build.sbt b/build.sbt
index f87a73cc..3316670e 100755
--- a/build.sbt
+++ b/build.sbt
@@ -1,6 +1,6 @@
 name := "spark-xml"
 
-version := "0.4.1"
+version := "0.4.2"
 
 organization := "com.databricks"
 
diff --git a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
index 2783726b..ca34eb49 100644
--- a/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
+++ b/src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
@@ -54,6 +54,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
   private var startTag: Array[Byte] = _
   private var currentStartTag: Array[Byte] = _
   private var endTag: Array[Byte] = _
+  private var endEmptyTag: Array[Byte] = _
   private var space: Array[Byte] = _
   private var angleBracket: Array[Byte] = _
 
@@ -75,6 +76,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
       Charset.forName(conf.get(XmlInputFormat.ENCODING_KEY, XmlOptions.DEFAULT_CHARSET))
     startTag = conf.get(XmlInputFormat.START_TAG_KEY).getBytes(charset)
     endTag = conf.get(XmlInputFormat.END_TAG_KEY).getBytes(charset)
+    endEmptyTag = "/>".getBytes(charset)
     space = " ".getBytes(charset)
     angleBracket = ">".getBytes(charset)
     require(startTag != null, "Start tag cannot be null.")
@@ -187,10 +189,16 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
     false
   }
 
+  private def checkEmptyTag(currentLetter: Int, position: Int): Boolean = {
+    if (position >= endEmptyTag.length) false
+    else currentLetter == endEmptyTag(position)
+  }
+
   private def readUntilEndElement(): Boolean = {
     var si = 0
     var ei = 0
     var depth = 0
+
     while (true) {
       val rb = in.read()
       if (rb == -1) {
@@ -199,7 +207,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
       } else {
         buffer.write(rb)
         val b = rb.toByte
-        if (b == startTag(si) && b == endTag(ei)) {
+        if (b == startTag(si) && (b == endTag(ei) || checkEmptyTag(b, ei))) {
           // In start tag or end tag.
           si += 1
           ei += 1
@@ -214,8 +222,9 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
             si += 1
             ei = 0
           }
-        } else if (b == endTag(ei)) {
-          if (ei >= endTag.length - 1) {
+        } else if (b == endTag(ei) || checkEmptyTag(b, ei)) {
+          if ((b == endTag(ei) && ei >= endTag.length - 1) ||
+            (checkEmptyTag(b, ei) && ei >= endEmptyTag.length - 1)) {
             if (depth == 0) {
               // Found closing end tag.
               return true
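
The XmlInputFormat hunks above teach the record reader to treat a self-closing row element (e.g. <House .../>) as a complete record: while scanning byte-by-byte for the configured end tag, it now also tracks a parallel match against "/>". The following standalone sketch is illustrative only, not library code — the object, method, and tag names (EmptyTagScanSketch, findRecordEnd, "House") are made up — but it shows the same dual-terminator scan in isolation:

// Minimal sketch of the dual-terminator scan, assuming a row tag of "House".
object EmptyTagScanSketch {
  private val EndTag = "</House>".getBytes("UTF-8")
  private val EndEmptyTag = "/>".getBytes("UTF-8")

  // Returns the offset just past the first terminator ("</House>" or "/>"), or -1.
  def findRecordEnd(bytes: Array[Byte]): Int = {
    var ei = 0 // current match position within both terminators
    var i = 0
    while (i < bytes.length) {
      val b = bytes(i)
      val inEndTag = b == EndTag(ei)
      val inEmptyTag = ei < EndEmptyTag.length && b == EndEmptyTag(ei)
      if (inEndTag || inEmptyTag) {
        if ((inEndTag && ei == EndTag.length - 1) ||
            (inEmptyTag && ei == EndEmptyTag.length - 1)) {
          return i + 1
        }
        ei += 1
      } else {
        ei = 0
      }
      i += 1
    }
    -1
  }

  def main(args: Array[String]): Unit = {
    println(findRecordEnd("<House HOUSEID=\"1\"></House>".getBytes("UTF-8"))) // 27
    println(findRecordEnd("<House HOUSEID=\"1\"/>".getBytes("UTF-8")))        // 20
  }
}

Previously only the explicit end tag terminated the scan, so a row consisting of a single self-closing element never produced a record boundary.
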
diff --git a/src/main/scala/com/databricks/spark/xml/XmlOptions.scala b/src/main/scala/com/databricks/spark/xml/XmlOptions.scala
index f89166ab..bc9e26f8 100644
--- a/src/main/scala/com/databricks/spark/xml/XmlOptions.scala
+++ b/src/main/scala/com/databricks/spark/xml/XmlOptions.scala
@@ -57,12 +57,13 @@ private[xml] class XmlOptions(
     logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")
   }
 
+  require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
+
   val failFast = ParseModes.isFailFastMode(parseMode)
   val dropMalformed = ParseModes.isDropMalformedMode(parseMode)
   val permissive = ParseModes.isPermissiveMode(parseMode)
 
   require(rowTag.nonEmpty, "'rowTag' option should not be empty string.")
-  require(attributePrefix.nonEmpty, "'attributePrefix' option should not be empty string.")
   require(valueTag.nonEmpty, "'valueTag' option should not be empty string.")
   require(valueTag != attributePrefix,
     "'valueTag' and 'attributePrefix' options should not be the same.")
diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
index 7a03fc8e..a8c75c69 100644
--- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
+++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParser.scala
@@ -234,16 +234,15 @@ private[xml] object StaxXmlParser {
       options: XmlOptions,
       rootAttributes: Array[Attribute] = Array.empty): Row = {
     val row = new Array[Any](schema.length)
+    val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+    // If there are attributes, then we process them first.
+    convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
+      nameToIndex.get(f).foreach { row(_) = v }
+    }
     var shouldStop = false
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
-          val nameToIndex = schema.map(_.name).zipWithIndex.toMap
-          // If there are attributes, then we process them first.
-          convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
-            nameToIndex.get(f).foreach { row(_) = v }
-          }
-
           val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
           val field = e.asStartElement.getName.getLocalPart
 
diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala
index 8f770750..03363c9f 100644
--- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala
+++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala
@@ -23,6 +23,7 @@ private[xml] object StaxXmlParserUtils {
   def checkEndElement(parser: XMLEventReader): Boolean = {
     parser.peek match {
       case _: EndElement => true
+      case _: EndDocument => true
       case _: StartElement => false
       case _ =>
         // When other events are found here rather than `EndElement` or `StartElement`
diff --git a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
index bbdb37cc..64bce120 100644
--- a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
+++ b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala
@@ -170,18 +170,17 @@ private[xml] object InferSchema {
       rootAttributes: Array[Attribute] = Array.empty): DataType = {
     val builder = Seq.newBuilder[StructField]
     val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
+    // If there are attributes, then we should process them first.
+    val rootValuesMap =
+      StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
+    rootValuesMap.foreach {
+      case (f, v) =>
+        nameToDataType += (f -> ArrayBuffer(inferFrom(v, options)))
+    }
     var shouldStop = false
     while (!shouldStop) {
       parser.nextEvent match {
         case e: StartElement =>
-          // If there are attributes, then we should process them first.
-          val rootValuesMap =
-            StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
-          rootValuesMap.foreach {
-            case (f, v) =>
-              nameToDataType += (f -> ArrayBuffer(inferFrom(v, options)))
-          }
-
           val attributes = e.getAttributes.map(_.asInstanceOf[Attribute]).toArray
           val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
           val inferredType = inferField(parser, options) match {
@@ -222,7 +221,7 @@
     }
     // We need to manually merges the fields having the sames so that
     // This can be inferred as ArrayType.
-    nameToDataType.foreach{
+    nameToDataType.foreach {
      case (field, dataTypes) if dataTypes.length > 1 =>
        val elementType = dataTypes.reduceLeft(InferSchema.compatibleType(options))
        builder += StructField(field, ArrayType(elementType), nullable = true)
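
The parser and schema-inference hunks above move root-attribute handling out of the `case e: StartElement` branch. Both routines receive the row element's own attributes as `rootAttributes`; once inside the event loop, a self-closing row yields no child `StartElement` events, so attribute handling placed in that branch never ran and such rows came back with their attribute columns null. The added `EndDocument` case in `checkEndElement` covers peeking past the last such row. A small standalone StAX sketch (plain `javax.xml.stream`, not spark-xml code; the object name and sample XML are illustrative) showing that an empty element produces no child element events:

import java.io.StringReader
import javax.xml.stream.XMLInputFactory
import javax.xml.stream.events.{EndDocument, EndElement, StartElement}

// Prints the StAX event stream of a self-closing element: one StartElement for
// the row itself, then EndElement and EndDocument -- no child StartElement at all.
object EmptyElementEvents {
  def main(args: Array[String]): Unit = {
    val reader = XMLInputFactory.newInstance()
      .createXMLEventReader(new StringReader("""<House HOUSEID="1"/>"""))
    while (reader.hasNext) {
      reader.nextEvent() match {
        case e: StartElement => println(s"StartElement ${e.getName.getLocalPart}")
        case e: EndElement   => println(s"EndElement ${e.getName.getLocalPart}")
        case _: EndDocument  => println("EndDocument")
        case e               => println(s"Other event: ${e.getEventType}")
      }
    }
  }
}
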
diff --git a/src/test/resources/fias_house.xml b/src/test/resources/fias_house.xml
new file mode 100644
index 00000000..ab3e5140
--- /dev/null
+++ b/src/test/resources/fias_house.xml
@@ -0,0 +1,182 @@
[... 182 added lines of self-closing <House .../> test records omitted ...]
\ No newline at end of file
diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
index d3eb8966..281a7447 100755
--- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
+++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala
@@ -28,7 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import com.databricks.spark.xml.XmlOptions._
-import com.databricks.spark.xml.util.ParseModes
+import com.databricks.spark.xml.util.{ParseModes, XmlFile}
+
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Row, SQLContext, SaveMode}
 import org.apache.spark.{SparkConf, SparkContext, SparkException}
@@ -59,11 +60,13 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
   val simpleNestedObjects = "src/test/resources/simple-nested-objects.xml"
   val nestedElementWithNameOfParent = "src/test/resources/nested-element-with-name-of-parent.xml"
   val booksMalformedAttributes = "src/test/resources/books-malformed-attributes.xml"
+  val fiasHouse = "src/test/resources/fias_house.xml"
 
   val booksTag = "book"
   val booksRootTag = "books"
   val topicsTag = "Topic"
   val agesTag = "person"
+  val fiasRowTag = "House"
 
   val numAges = 3
   val numCars = 3
@@ -71,6 +74,7 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
   val numBooksComplicated = 3
   val numTopics = 1
   val numGPS = 2
+  val numFiasHouses = 37
 
   private var sqlContext: SQLContext = _
 
@@ -903,4 +907,14 @@ class XmlSuite extends FunSuite with BeforeAndAfterAll {
     assert(results(0)(0) === "bk111")
     assert(results(1)(0) === "bk112")
   }
+
+  test("read utf-8 encoded file with empty tag") {
+    val df = sqlContext.read.format("xml")
+      .option("excludeAttribute", "false")
+      .option("rowTag", fiasRowTag)
+      .xml(fiasHouse)
+
+    assert(df.collect().length == numFiasHouses)
+    assert(df.select().where("_HOUSEID is null").count() == 0)
+  }
 }
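
For reference, a hedged end-to-end usage sketch of the behaviour the new tests exercise, using the 0.4.x SQLContext-based API. The path, row tag, and 37-row count come from the bundled fias_house.xml test resource; the master/app-name settings and the object name are illustrative, and attribute columns (such as _HOUSEID) carry the default "_" prefix:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Reads rows that are self-closing elements carrying only attributes.
object ReadEmptyTagRows {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("empty-tags"))
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.read
      .format("xml")
      .option("rowTag", "House")            // each <House .../> becomes one row
      .option("excludeAttribute", "false")  // keep attributes as "_"-prefixed columns
      .load("src/test/resources/fias_house.xml")

    df.printSchema()                        // attribute columns such as _HOUSEID
    println(df.count())                     // 37 for the bundled test file
    sc.stop()
  }
}
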
"House" + val numHouses = 37 + val utf8 = "utf-8" private var sparkContext: SparkContext = _ @@ -58,6 +62,12 @@ class XmlFileSuite extends FunSuite with BeforeAndAfterAll { assert(baseRDD.count() === numBooksUnicodeInTagName) } + test("read utf-8 encoded file with empty tag") { + val baseRDD = XmlFile.withCharset(sparkContext, fiasHouse, utf8, rowTag = fiasRowTag) + assert(baseRDD.count() == numHouses) + baseRDD.collect().foreach(x => assert(x.contains("/>"))) + } + test("unsupported charset") { val exception = intercept[UnsupportedCharsetException] { XmlFile.withCharset(sparkContext, booksFile, "frylock", rowTag = booksFileTag).count()