diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index cb83e8917682..045774a17ac1 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -187,6 +187,56 @@ def func(split, iterator):
         jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
         return SchemaRDD(jschema_rdd, self)
 
+    def csvFile(self, path, delimiter=",", quote="\"", header=False):
+        """
+        Loads a CSV file (according to RFC 4180) and returns the result as a L{SchemaRDD}.
+        The header flag specifies whether the first line of each file should be treated
+        as a header.
+
+        NOTE: If there are newline characters inside quoted fields, this method may fail
+        to parse correctly, because the two lines may be in different partitions. Use
+        L{SQLContext#csvRDD} to parse such files.
+
+        >>> import tempfile, shutil
+        >>> csvFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(csvFile)
+        >>> ofn = open(csvFile, 'w')
+        >>> for csvStr in csvStrings:
+        ...     print>>ofn, csvStr
+        >>> ofn.close()
+        >>> csv = sqlCtx.csvFile(csvFile, delimiter=", ", header=True)
+        >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
+        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable WHERE Make = 'Ford'")
+        >>> csvRes.collect()
+        [{u'Year': u'1997'}]
+        """
+        jschema_rdd = self._ssql_ctx.csvFile(path, delimiter, quote, header)
+        return SchemaRDD(jschema_rdd, self)
+
+    def csvRDD(self, rdd, delimiter=",", quote="\"", header=False):
+        """
+        Parses an RDD of strings as CSV (according to RFC 4180) and returns the result
+        as a L{SchemaRDD}.
+
+        NOTE: If there are newline characters inside quoted fields, use
+        L{SparkContext#wholeTextFiles} to read each file into a single partition.
+
+        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter=", ", header=True)
+        >>> sqlCtx.registerRDDAsTable(csvrdd, "csvTable2")
+        >>> csvRes = sqlCtx.sql("SELECT count(*) FROM csvTable2")
+        >>> csvRes.collect() == [{"c0": 3}]
+        True
+        """
+        def func(split, iterator):
+            for x in iterator:
+                if not isinstance(x, basestring):
+                    x = unicode(x)
+                yield x.encode("utf-8")
+        keyed = PipelinedRDD(rdd, func)
+        keyed._bypass_serializer = True
+        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
+        jschema_rdd = self._ssql_ctx.csvRDD(jrdd.rdd(), delimiter, quote, header)
+        return SchemaRDD(jschema_rdd, self)
+
     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
@@ -505,6 +555,12 @@ def _test():
     ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
+    csvStrings = ['Year, Make, Model, Description',
+                  '"1997", "Ford", "E350", ',
+                  '2000, Mercury, "Cougar", "Really ""Good"" car"',
+                  '2007, Honda, "Civic", ']
+    globs['csvStrings'] = csvStrings
+    globs['csv'] = sc.parallelize(csvStrings)
     globs['nestedRdd1'] = sc.parallelize([
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
         {"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4abd89955bd2..6d55e263cdc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.csv._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.SparkStrategies
 import org.apache.spark.sql.json._
@@ -129,6 +130,100 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD =
     new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio))
 
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[SchemaRDD]].
+   *
+   * NOTE: If there are newline characters inside quoted fields, this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param schema StructType object to specify the schema (field names and types). This will
+   *               override field names if a header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): SchemaRDD = {
+    val csv = sparkContext.textFile(path)
+    csvRDD(csv, schema, delimiter, quote, header)
+  }
+
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[SchemaRDD]].
+   * It infers the schema based on the first record.
+   *
+   * NOTE: If there are newline characters inside quoted fields, this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      delimiter: String = ",",
+      quote: Char = '"',
+      header: Boolean = false): SchemaRDD = {
+    val csv = sparkContext.textFile(path)
+    csvRDD(csv, delimiter, quote, header)
+  }
+
+  /**
+   * Parses an RDD of strings as CSV (according to RFC 4180) and returns the result as a
+   * [[SchemaRDD]].
+   *
+   * NOTE: If there are newline characters inside quoted fields, use
+   * [[SparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param schema StructType object to specify the schema (field names and types). This will
+   *               override field names if a header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: RDD[String],
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): SchemaRDD = {
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, Some(schema), header))
+  }
+
+  /**
+   * Parses an RDD of strings as CSV (according to RFC 4180) and returns the result as a
+   * [[SchemaRDD]]. It infers the schema based on the first record.
+   *
+   * NOTE: If there are newline characters inside quoted fields, use
+   * [[SparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: RDD[String],
+      delimiter: String = ",",
+      quote: Char = '"',
+      header: Boolean = false): SchemaRDD = {
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, None, header))
+  }
+
   /**
    * :: Experimental ::
    * Creates an empty parquet file with the schema of class `A`, which can be registered as a table.
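Reviewer note: to make the intended call pattern concrete, here is a sketch of how the new
SQLContext methods compose. The file path, table name, and the typed schema are illustrative,
not part of this patch; `sc` is assumed to be an existing SparkContext.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.types.{IntegerType, StringType, StructField, StructType}

val sqlContext = new SQLContext(sc)

// Inferred schema: field names come from the header line and every field is StringType.
val cars = sqlContext.csvFile("cars.csv", delimiter = ", ", header = true)
cars.registerAsTable("cars")
sqlContext.sql("SELECT Make, Model FROM cars WHERE Year = '1997'").collect()

// Explicit schema: the StructType supplies names and types, overriding the header;
// string tokens are cast to the requested types by CsvRDD.schemaCaster.
val schema = StructType(Seq(
  StructField("year", IntegerType, nullable = true),
  StructField("make", StringType, nullable = true),
  StructField("model", StringType, nullable = true),
  StructField("comment", StringType, nullable = true)))
val typedCars = sqlContext.csvFile("cars.csv", schema, ", ", '"', true)
```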
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 790d9ef22cf1..999d1e11b6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.sql.csv.CsvRDD
 import org.apache.spark.sql.json.JsonRDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
@@ -122,6 +123,100 @@ class JavaSQLContext(val sqlContext: SQLContext) {
   def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD =
     new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0))
 
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
+   *
+   * NOTE: If there are newline characters inside quoted fields, this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[JavaSQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param schema StructType object to specify the schema (field names and types). This will
+   *               override field names if a header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    val csv = sqlContext.sparkContext.textFile(path)
+    csvRDD(csv, schema, delimiter, quote, header)
+  }
+
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
+   * It infers the schema based on the first record.
+   *
+   * NOTE: If there are newline characters inside quoted fields, this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[JavaSQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    val csv = sqlContext.sparkContext.textFile(path)
+    csvRDD(csv, delimiter, quote, header)
+  }
+
+  /**
+   * Parses an RDD of strings as CSV (according to RFC 4180) and returns the result as a
+   * [[JavaSchemaRDD]].
+   *
+   * NOTE: If there are newline characters inside quoted fields, use
+   * [[JavaSparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param schema StructType object to specify the schema (field names and types). This will
+   *               override field names if a header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: JavaRDD[String],
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, Some(schema), header))
+  }
+
+  /**
+   * Parses an RDD of strings as CSV (according to RFC 4180) and returns the result as a
+   * [[JavaSchemaRDD]]. It infers the schema based on the first record.
+   *
+   * NOTE: If there are newline characters inside quoted fields, use
+   * [[JavaSparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character (default is '"')
+   * @param header Optional flag to indicate that the first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: JavaRDD[String],
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, None, header))
+  }
+
   /**
    * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
    * during the lifetime of this instance of SQLContext.
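Reviewer note: the Java-facing overloads spell out every parameter because Scala default
arguments are not visible to Java callers, which is why there are no `= ","`-style defaults
here. A sketch of the expected usage, written in Scala against the Java API; `sc` and the
file path are assumptions:

```scala
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.api.java.JavaSQLContext

val jsc = new JavaSparkContext(sc)
val javaSqlCtx = new JavaSQLContext(jsc)

// All four parameters are passed explicitly across the Java boundary.
val cars = javaSqlCtx.csvFile("cars.csv", ", ", '"', true)
javaSqlCtx.registerRDDAsTable(cars, "cars")
```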
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
new file mode 100644
index 000000000000..9180af380e5d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.csv
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.sql.Logging
+
+private[sql] object CsvRDD extends Logging {
+
+  /**
+   * Infers the schema of a CSV file. It uses the first row of the first partition to
+   * infer the number of columns. If the header flag is set, all lines that equal the
+   * first line are filtered out before parsing.
+   *
+   *  - If a line contains fewer tokens than the schema, it is padded with nulls.
+   *  - If a line has more tokens than the schema, extra tokens are ignored.
+   */
+  private[sql] def inferSchema(
+      csv: RDD[String],
+      delimiter: String,
+      quote: Char,
+      userSchema: Option[StructType],
+      useHeader: Boolean): LogicalPlan = {
+
+    val firstLine = csv.first()
+    val schema = userSchema match {
+      case Some(userSuppliedSchema) => userSuppliedSchema
+      case None =>
+        // Assume the first row is representative and use it to determine the number of fields
+        val firstRow = new CsvTokenizer(Seq(firstLine).iterator, delimiter, quote).next()
+        val header = if (useHeader) {
+          logger.info(s"Using header line: $firstLine")
+          firstRow
+        } else {
+          firstRow.zipWithIndex.map { case (value, index) => s"V$index" }
+        }
+        // By default fields are assumed to be StringType
+        val schemaFields = header.map { fieldName =>
+          StructField(fieldName, StringType, nullable = true)
+        }
+        StructType(schemaFields)
+    }
+
+    val numFields = schema.fields.length
+    logger.info(s"Parsing CSV with $numFields fields.")
+    val row = new GenericMutableRow(numFields)
+    val projection = schemaCaster(asAttributes(schema))
+
+    val parsedCSV = csv.mapPartitions { iter =>
+      // When using a header, any input line that equals firstLine is assumed to be header
+      val csvIter = if (useHeader) {
+        iter.filter(_ != firstLine)
+      } else {
+        iter
+      }
+      val tokenIter = new CsvTokenizer(csvIter, delimiter, quote)
+      parseCSV(tokenIter, schema.fields, projection, row)
+    }
+
+    SparkLogicalPlan(ExistingRdd(asAttributes(schema), parsedCSV))
+  }
+
+  protected def schemaCaster(schema: Seq[AttributeReference]): MutableProjection = {
+    val startSchema = (1 to schema.length).map(
+      index => new AttributeReference(s"V$index", StringType, nullable = true)())
+    val casts = schema.zipWithIndex.map { case (ar, i) => Cast(startSchema(i), ar.dataType) }
+    new MutableProjection(casts, startSchema)
+  }
+
+  private def parseCSV(
+      iter: Iterator[Array[String]],
+      schemaFields: Seq[StructField],
+      projection: MutableProjection,
+      row: GenericMutableRow): Iterator[Row] = {
+    iter.map { tokens =>
+      schemaFields.zipWithIndex.foreach {
+        // Lines with fewer tokens than the schema are padded with nulls
+        case (StructField(name, dataType, _), index) =>
+          row.update(index, if (index < tokens.length) tokens(index) else null)
+      }
+      projection(row)
+    }
+  }
+
+  private def asAttributes(struct: StructType): Seq[AttributeReference] = {
+    struct.fields.map(field => AttributeReference(field.name, field.dataType, nullable = true)())
+  }
+}
+
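Reviewer note: the padding/truncation contract documented on `inferSchema` (shorter rows
are null-padded, longer rows are truncated) is easy to state in isolation. A minimal,
dependency-free sketch with names of my own choosing:

```scala
// Shorter rows are padded with nulls; longer rows are truncated to the schema width.
def padOrTruncate(tokens: Array[String], numFields: Int): Seq[String] =
  tokens.padTo(numFields, null: String).take(numFields).toSeq

assert(padOrTruncate(Array("a", "b"), 3) == Seq("a", "b", null))
assert(padOrTruncate(Array("a", "b", "c", "d"), 3) == Seq("a", "b", "c"))
```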
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
new file mode 100644
index 000000000000..bb86aff5613e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.csv
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Tokenizer based on RFC 4180 for comma-separated values.
+ * It implements an iterator that returns each tokenized line as an Array[String].
+ */
+private[sql] class CsvTokenizer(
+    inputIter: Iterator[String],
+    delimiter: String,
+    quote: Char) extends Iterator[Array[String]] {
+
+  private val DELIM = delimiter.charAt(0)
+  private val QUOTE = quote
+  private val DOUBLEQUOTE = quote.toString * 2
+  private val BACKSLASH = '\\'
+  private val NEWLINE = '\n'
+
+  // Maximum number of input lines a single quoted field may span
+  private val MAX_QUOTED_LINES = 10
+
+  private val delimLength = delimiter.length
+
+  private def isDelimAt(inStr: String, index: Int): Boolean = {
+    inStr.substring(index, index + delimLength) == delimiter
+  }
+
+  private def isDoubleQuoteAt(inStr: String, index: Int): Boolean = {
+    inStr.substring(index, index + 2) == DOUBLEQUOTE
+  }
+
+  private def stripQuotes(inStr: StringBuilder): String = {
+    val end = inStr.length - 1
+    if (inStr.charAt(0) == QUOTE && inStr.charAt(end) == QUOTE) {
+      inStr.deleteCharAt(end).deleteCharAt(0).toString()
+    } else {
+      inStr.toString()
+    }
+  }
+
+  import QuoteState._
+
+  def hasNext: Boolean = inputIter.hasNext
+
+  def next(): Array[String] = {
+    var curState = Unquoted
+    var curPosition = 0
+    var startPosition = 0
+    var curChar: Char = '\0'
+    var quotedLines = 0
+    val leftOver = new StringBuilder()  // Used to keep track of tokens that span multiple lines
+    val tokens = new ArrayBuffer[String]()
+    var line = inputIter.next() + '\n'
+
+    while (curPosition < line.length) {
+      curChar = line.charAt(curPosition)
+
+      if (curChar == QUOTE) {
+        if (curState == Quoted) {
+          if (isDoubleQuoteAt(line, curPosition)) {
+            // An escaped quote ("") inside a quoted field: keep a single quote character
+            leftOver.append(line.substring(startPosition, curPosition + 1))
+            curPosition += 2
+            startPosition = curPosition
+          } else {
+            curState = Unquoted
+            curPosition += 1
+          }
+        } else {
+          curState = Quoted
+          curPosition += 1
+        }
+      } else if (curChar == DELIM) {
+        if (curState == Unquoted && isDelimAt(line, curPosition) && curPosition > startPosition) {
+          leftOver.append(line.substring(startPosition, curPosition))
+          tokens.append(stripQuotes(leftOver))
+          leftOver.clear()
+          quotedLines = 0
+          curPosition += delimLength
+          startPosition = curPosition
+        } else {
+          curPosition += 1
+        }
+      } else if (curChar == NEWLINE) {
+        if (curState == Quoted && quotedLines < MAX_QUOTED_LINES) {
+          // A quoted field continues on the next input line
+          leftOver.append(line.substring(startPosition, curPosition + 1))
+          line = inputIter.next() + '\n'
+          curPosition = 0
+          startPosition = 0
+          quotedLines += 1
+        } else {
+          if (curPosition == startPosition) {
+            tokens.append(null)
+          } else {
+            leftOver.append(line.substring(startPosition, curPosition))
+            tokens.append(stripQuotes(leftOver))
+          }
+          leftOver.clear()
+          quotedLines = 0
+          curPosition += 1
+        }
+      } else if (curChar == BACKSLASH) {
+        // Skip over the backslash and the character it escapes
+        curPosition += 2
+      } else {
+        curPosition += 1
+      }
+    }
+    tokens.toArray
+  }
+}
+
+private[sql] object QuoteState extends Enumeration {
+  type State = Value
+  val Quoted, Unquoted = Value
+}
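Reviewer note: a short walkthrough of the tokenizer on the doctest rows from sql.py, as I
read the state machine. The expected values follow from RFC 4180: a doubled quote inside a
quoted field collapses to a single quote, and an empty field after the final delimiter comes
back as null. Since CsvTokenizer is private[sql], this snippet would have to live inside
org.apache.spark.sql.csv:

```scala
val lines = Seq(
  """2000, Mercury, "Cougar", "Really ""Good"" car"""",
  """2007, Honda, "Civic", """).iterator
val tokenizer = new CsvTokenizer(lines, ", ", '"')

// "" inside a quoted field is an escaped quote
assert(tokenizer.next().toSeq == Seq("2000", "Mercury", "Cougar", "Really \"Good\" car"))
// the empty trailing field is returned as null, not as an empty string
assert(tokenizer.next().toSeq == Seq("2007", "Honda", "Civic", null))
```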
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
new file mode 100644
index 000000000000..43359f34675a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.csv
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.TestSQLContext._
+
+class CsvSuite extends QueryTest {
+  import TestCsvData._
+
+  test("Simple CSV with header") {
+    val csvSchemaRDD = csvRDD(diamondCSVWithHeader, header = true)
+    csvSchemaRDD.registerAsTable("diamonds")
+
+    checkAnswer(
+      sql("select cut from diamonds limit 3"),
+      Seq(Seq("Ideal"), Seq("Premium"), Seq("Good")))
+
+    checkAnswer(
+      sql("select count(*) from diamonds"),
+      3
+    )
+  }
+
+  test("Simple CSV without header") {
+    val csvSchemaRDD = csvRDD(salesCSVWithoutHeader, delimiter = "; ")
+    csvSchemaRDD.registerAsTable("sales")
+
+    checkAnswer(
+      sql("select distinct V0 from sales"),
+      "2003"
+    )
+  }
+
+  test("Quoted CSV with new lines") {
+    val csvSchemaRDD = csvRDD(carCSVWithQuotes, delimiter = ", ", header = true)
+    csvSchemaRDD.registerAsTable("cars")
+
+    checkAnswer(
+      sql("select Model from cars limit 1"),
+      """Ford
+        |Pampa""".stripMargin
+    )
+
+    checkAnswer(
+      sql("select distinct Make from cars"),
+      "Ford"
+    )
+  }
+
+  test("Custom quoted CSV with inner quotes") {
+    val csvSchemaRDD = csvRDD(salesCSVWithDoubleQuotes, delimiter = "; ", quote = '|')
+    csvSchemaRDD.registerAsTable("quotedSales")
+
+    checkAnswer(
+      sql("select distinct V0 from quotedSales"),
+      "2003"
+    )
+
+    checkAnswer(
+      sql("select distinct V2 from quotedSales where V2 like '%Mac%'"),
+      """Mac "Power" Adapter"""
+    )
+
+    checkAnswer(
+      sql("select distinct V2 from quotedSales where V2 like '%iPad%'"),
+      """iPad |Power| Adapter"""
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala
new file mode 100644
index 000000000000..b6c74fe30895
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.csv
+
+import org.scalatest.FunSuite
+
+class CsvTokenizerSuite extends FunSuite {
+
+  def runTokenizer(
+      tokenLines: List[List[String]],
+      delimiter: String = ",",
+      quote: Char = '"'): Iterator[List[String]] = {
+    val linesIter = tokenLines.map(_.mkString(delimiter))
+    new CsvTokenizer(linesIter.iterator, delimiter, quote).map(_.toList)
+  }
+
+  test("Tokenize simple delimited fields") {
+    val tokens = List("V1", "V2", "V3", "V4", "V5")
+    val line = tokens.mkString("; ")
+    val tokenizer = new CsvTokenizer(Seq(line).iterator, delimiter = "; ", quote = '"')
+
+    assert(tokenizer.next().toList === tokens)
+  }
+
+  test("Tokenize many lines of delimited fields") {
+    val seed = List("AAA", "BBB", "CCC", "DDD")
+    val tokenLines = (1 to 10).toList.map { lineNumber =>
+      seed.map(_ * lineNumber + s"-$lineNumber")
+    }
+    assert(runTokenizer(tokenLines).toList === tokenLines)
+  }
+
+  test("Tokenize many lines of delimited fields containing quotes") {
+    val seed = List("V1", "V2", "V3", "V4", "V5")
+    val tokenLines = (1 to 10).toList.map { lineNumber =>
+      seed.map(_ + s"-$lineNumber")
+    }
+
+    val quotedTokenLines = tokenLines.map { tokens =>
+      tokens.map("\"" + _ + "\"")
+    }
+
+    assert(runTokenizer(quotedTokenLines).toList === tokenLines)
+  }
+
+  test("Tokenize many lines of delimited fields containing double quotes inside quotes") {
+    val seed = List("V1", "V2", "V3", "V4", "V5")
+    val tokenLines = (1 to 10).toList.map { lineNumber =>
+      seed.map(_ + s""" ""quoted stuff""-$lineNumber""")
+    }
+
+    val parsedTokenLines = (1 to 10).toList.map { lineNumber =>
+      seed.map(_ + s""" "quoted stuff"-$lineNumber""")
+    }
+
+    val quotedTokenLines = tokenLines.map { tokens =>
+      tokens.map("\"" + _ + "\"")
+    }
+
+    assert(runTokenizer(quotedTokenLines).toList === parsedTokenLines)
+  }
+
+  test("Tokenize delimited fields containing new lines inside quotes") {
+    val headerLine = "Make; Model; Year; Note"
+    val firstLine = """Honda; Civic; 2006; "Reliable"""
+    val secondLine = """ car""""
+    val thirdLine = """Toyota; Camry; 2006; "Best"""
+    val fourthLine = "selling"
+    val fifthLine = """car""""
+
+    val tokenLines = List(headerLine, firstLine, secondLine, thirdLine, fourthLine, fifthLine)
+    val tokenizer = new CsvTokenizer(tokenLines.iterator, delimiter = "; ", quote = '"')
+
+    val firstRow = tokenizer.next()
+    val secondRow = tokenizer.next()
+    val thirdRow = tokenizer.next()
+
+    assert(firstRow.toSeq === Seq("Make", "Model", "Year", "Note"))
+    assert(secondRow(3) === """Reliable
+      | car""".stripMargin)
+    assert(thirdRow(3) === """Best
+      |selling
+      |car""".stripMargin)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala
new file mode 100644
index 000000000000..832c166b88ba
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.csv
+
+import org.apache.spark.sql.test.TestSQLContext
+
+object TestCsvData {
+
+  val diamondCSVWithHeader = TestSQLContext.sparkContext.parallelize(Array(
+    "carat,cut,color,clarity,depth,table,price,x,y,z",
+    "0.23,Ideal,E,SI2,61.5,55,326,3.95,3.98,2.43",
+    "0.21,Premium,E,SI1,59.8,61,326,3.89,3.84,2.31",
+    "0.23,Good,E,VS1,56.9,65,327,4.05,4.07,2.31"), 1)
+
+  val salesCSVWithoutHeader = TestSQLContext.sparkContext.parallelize(Array(
+    "2003; USA; MacBook; 11580.401776",
+    "2003; USA; Power Adapter; 4697.495022",
+    "2003; FRA; MacBook; 5910.787393",
+    "2003; FRA; Power Adapter; 2758.903949",
+    "2003; RUS; MacBook; 6992.729325"
+  ))
+
+  val carCSVWithQuotes = TestSQLContext.sparkContext.parallelize(Array(
+    """Year, Make, Model, Desc""",
+    """"1964", Ford, "Ford""",
+    """Pampa", manufactured by Ford of Brazil""",
+    """"1947", "Ford", "Ford Pilot", "The Pilot was the first large""",
+    """post-war British Ford"""",
+    """1997, Ford, Mustang, """
+  ), 1)
+
+  val salesCSVWithDoubleQuotes = TestSQLContext.sparkContext.parallelize(Array(
+    """|2003|; USA; |Mac "Power" Adapter|; 4697.495022""",
+    """|2003|; FRA; |iPhone "Power" Adapter|; 2758.903949""",
+    """|2003|; FRA; |iPad ||Power|| Adapter|; 2758.903949"""
+  ), 1)
+}
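Reviewer note: closing the loop on the wholeTextFiles suggestion repeated in the scaladocs
above, a hedged sketch of the workaround for files whose quoted fields contain newlines.
flatMap keeps each file's lines in a single partition, in order, which is what the tokenizer
needs; the path and `sc` are illustrative:

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Each file arrives as one (path, content) pair, so its lines stay together and in order.
val lines = sqlContext.sparkContext.wholeTextFiles("data/cars")
  .flatMap { case (_, content) => content.split("\n") }
val cars = sqlContext.csvRDD(lines, delimiter = ", ", header = true)
cars.registerAsTable("cars")
```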