From 177eb062f6bdbf0f59d8a29ce1ff620ec8fe5ec9 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 7 Jul 2014 19:12:54 -0700 Subject: [PATCH 01/20] Basic version of csv parsing --- .../org/apache/spark/sql/SQLContext.scala | 17 ++++ .../org/apache/spark/sql/csv/CsvRDD.scala | 90 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 7edb548678c3..812d74a3a410 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ +import org.apache.spark.sql.csv._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.SparkContext @@ -129,6 +130,22 @@ class SQLContext(@transient val sparkContext: SparkContext) def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + + def csvFile(path: String, + delimiter: String = ",", + useHeader: Boolean = false, + quote: String = "\""): SchemaRDD = { + val csv = sparkContext.textFile(path) + csvRdd(csv, delimiter, useHeader) + } + + def csvRdd(csv: RDD[String], + delimiter: String = ",", + useHeader: Boolean = false, + quote: String = "\""): SchemaRDD = { + new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, useHeader)) + } + /** * :: Experimental :: * Creates an empty parquet file with the schema of class `A`, which can be registered as a table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala new file mode 100644 index 000000000000..eb170876a801 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.csv + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, AttributeReference, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.Logging + +private[sql] object CsvRDD extends Logging { + + private[sql] def inferSchema( + csv: RDD[String], + delimiter: String = ",", + useHeader: Boolean = false): LogicalPlan = { + + // TODO: Read header. 
For now assume there is no header + // TODO: What if first row is not representative + val firstLine = csv.first() + val firstRow = firstLine.split(delimiter) + val header = if (useHeader) { + firstRow + } else { + firstRow.zipWithIndex.map { case (value, index) => s"V$index" } + } + + // TODO: Infer types based on a sample + // TODO: Figure out a way for user to specify types/schema + val fields = header.map( fieldName => StructField(fieldName, StringType, nullable = true)) + val schema = StructType(fields) + val parsedCSV = csv.mapPartitions { iter => + val csvIter = if (useHeader) { + // Any input line that equals the headerLine is assumed to be header and filtered + iter.filter(_ != firstLine) + } else { + iter + } + parseCSV(csvIter, delimiter, schema) + } + + SparkLogicalPlan(ExistingRdd(asAttributes(schema), parsedCSV)) + } + + private def castToType(value: Any, dataType: DataType): Any = dataType match { + case StringType => value.asInstanceOf[String] + case BooleanType => value.asInstanceOf[Boolean] + case DoubleType => value.asInstanceOf[Double] + case FloatType => value.asInstanceOf[Float] + case IntegerType => value.asInstanceOf[Int] + case LongType => value.asInstanceOf[Long] + case ShortType => value.asInstanceOf[Short] + case _ => null + } + + private def parseCSV(iter: Iterator[String], + delimiter: String, + schema: StructType): Iterator[Row] = { + val row = new GenericMutableRow(schema.fields.length) + iter.map { line => + val tokens = line.split(delimiter) + schema.fields.zipWithIndex.foreach { + case (StructField(name, dataType, _), index) => + row.update(index, castToType(tokens(index), dataType)) + } + row + } + } + + private def asAttributes(struct: StructType): Seq[AttributeReference] = { + struct.fields.map(field => AttributeReference(field.name, field.dataType, nullable = true)()) + } + +} \ No newline at end of file From 510df2ef8bec5a963892af49cfcc88716abf9c93 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 9 Jul 2014 19:19:58 -0700 Subject: [PATCH 02/20] RFC 4180 compatible tokenizer --- .../org/apache/spark/sql/csv/CsvRDD.scala | 33 ++--- .../apache/spark/sql/csv/CsvTokenizer.scala | 119 ++++++++++++++++++ 2 files changed, 136 insertions(+), 16 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala index eb170876a801..57d0a7bb8035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala @@ -28,31 +28,35 @@ private[sql] object CsvRDD extends Logging { private[sql] def inferSchema( csv: RDD[String], - delimiter: String = ",", - useHeader: Boolean = false): LogicalPlan = { + delimiter: String, + quote: String, + useHeader: Boolean): LogicalPlan = { - // TODO: Read header. 
For now assume there is no header - // TODO: What if first row is not representative + // Constructing schema + // TODO: Infer types based on a sample and/or let user specify types/schema val firstLine = csv.first() - val firstRow = firstLine.split(delimiter) + // Assuming first row is representative and using it to determine number of fields + val firstRow = new CsvTokenizer(Seq(firstLine).iterator, delimiter, quote).next() val header = if (useHeader) { firstRow } else { firstRow.zipWithIndex.map { case (value, index) => s"V$index" } } - // TODO: Infer types based on a sample - // TODO: Figure out a way for user to specify types/schema - val fields = header.map( fieldName => StructField(fieldName, StringType, nullable = true)) - val schema = StructType(fields) + val schemaFields = header.map { fieldName => + StructField(fieldName.asInstanceOf[String], StringType, nullable = true) + } + val schema = StructType(schemaFields) + val parsedCSV = csv.mapPartitions { iter => + // When using header, any input line that equals firstLine is assumed to be header val csvIter = if (useHeader) { - // Any input line that equals the headerLine is assumed to be header and filtered iter.filter(_ != firstLine) } else { iter } - parseCSV(csvIter, delimiter, schema) + val tokenIter = new CsvTokenizer(csvIter, delimiter, quote) + parseCSV(tokenIter, schema) } SparkLogicalPlan(ExistingRdd(asAttributes(schema), parsedCSV)) @@ -69,12 +73,9 @@ private[sql] object CsvRDD extends Logging { case _ => null } - private def parseCSV(iter: Iterator[String], - delimiter: String, - schema: StructType): Iterator[Row] = { + private def parseCSV(iter: Iterator[Array[Any]], schema: StructType): Iterator[Row] = { val row = new GenericMutableRow(schema.fields.length) - iter.map { line => - val tokens = line.split(delimiter) + iter.map { tokens => schema.fields.zipWithIndex.foreach { case (StructField(name, dataType, _), index) => row.update(index, castToType(tokens(index), dataType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala new file mode 100644 index 000000000000..1d5cc43fb528 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.csv + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.Logging + +/** + * Tokenizer based on RFC 4180 for comma separated values. + * It implements an iterator that returns each tokenized line as an Array[Any]. 
+ */ +private[sql] class CsvTokenizer(inputIter: Iterator[String], + delimiter: String, quote: String) extends Iterator[Array[Any]] with Logging { + + private val DELIM = delimiter.charAt(0) + private val QUOTE = quote.charAt(0) + private val BACKSLASH = '\\' + private val NEWLINE = '\n' + + private val delimLength = delimiter.length + private val quoteLength = quote.length + + private def isDelimAt(inStr: String, index: Int): Boolean = { + inStr.substring(index, index + delimLength) == delimiter + } + + private def isQuoteAt(inStr: String, index: Int): Boolean = { + inStr.substring(index, index + quoteLength) == quote + } + + private def stripQuotes(inStr: String): String = if (inStr.startsWith(quote)) { + inStr.stripPrefix(quote).stripSuffix(quote) + } else { + inStr + } + + import QuoteState._ + + def hasNext: Boolean = inputIter.hasNext + + def next(): Array[Any] = { + var curState = Unquoted + var curPosition = 0 + var startPosition = 0 + var curChar: Char = '\0' + var leftOver: String = "" // Used to keep track of tokens that span multiple lines + val tokens = new ArrayBuffer[Any]() + var line = inputIter.next() + '\n' + + while (curPosition < line.length) { + curChar = line.charAt(curPosition) + + (curState, curChar) match { + case (Quoted, QUOTE) => + if (isQuoteAt(line, curPosition)) { + curState = Unquoted + curPosition += quoteLength + } else { + curPosition += 1 + } + case (Quoted, NEWLINE) if inputIter.hasNext => + leftOver = leftOver + line.substring(startPosition, curPosition + 1) + line = inputIter.next() + '\n' + curPosition = 0 + startPosition = 0 + case (Unquoted, DELIM) => + if (isDelimAt(line, curPosition) && curPosition > startPosition) { + tokens.append(stripQuotes(leftOver + line.substring(startPosition, curPosition))) + curPosition += delimLength + startPosition = curPosition + leftOver = "" + } else { + curPosition += 1 + } + case (Unquoted, QUOTE) => + if (isQuoteAt(line, curPosition)) { + curState = Quoted + curPosition += quoteLength + } else { + curPosition += 1 + } + case (Unquoted, NEWLINE) => + if (startPosition == curPosition) { + tokens.append(null) + } else { + tokens.append(stripQuotes(leftOver + line.substring(startPosition, curPosition))) + } + curPosition += 1 + case (_, BACKSLASH) => + curPosition += 2 + case (_, _) => + curPosition += 1 + } + } + tokens.toArray + } + +} + +object QuoteState extends Enumeration { + type State = Value + val Quoted, Unquoted = Value +} \ No newline at end of file From 30a5ae52983c70816330a9c467c0d7a125564529 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 9 Jul 2014 19:20:21 -0700 Subject: [PATCH 03/20] Added API documentation --- .../org/apache/spark/sql/SQLContext.scala | 41 +++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 812d74a3a410..9335fcd55317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -130,20 +130,45 @@ class SQLContext(@transient val sparkContext: SparkContext) def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) - + /** + * Loads a CSV file (according to RFC 4180) and returns the result as a [[SchemaRDD]]. 
+   *
+   * NOTE: If there are new line characters inside quoted fields this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param delimiter Optional delimiter (default is comma)
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   * @param quote Optional quote character or string (default is '"')
+   */
   def csvFile(path: String,
       delimiter: String = ",",
-      useHeader: Boolean = false,
-      quote: String = "\""): SchemaRDD = {
+      quote: String = "\"",
+      header: Boolean = false): SchemaRDD = {
     val csv = sparkContext.textFile(path)
-    csvRdd(csv, delimiter, useHeader)
+    csvRDD(csv, delimiter, quote, header)
   }
 
-  def csvRdd(csv: RDD[String],
+  /**
+   * Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
+   * [[SchemaRDD]].
+   *
+   * NOTE: If there are new line characters inside quoted fields, use
+   * [[SparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param delimiter Optional delimiter (default is comma)
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   * @param quote Optional quote character or string (default is '"')
+   */
+  def csvRDD(csv: RDD[String],
       delimiter: String = ",",
-      useHeader: Boolean = false,
-      quote: String = "\""): SchemaRDD = {
-    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, useHeader))
+      quote: String = "\"",
+      header: Boolean = false): SchemaRDD = {
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, header))
   }
 
   /**

From ac95fcba25ab5175c58d6833c5d35bf7ac652370 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 9 Jul 2014 19:20:58 -0700
Subject: [PATCH 04/20] Added unit tests

---
 .../org/apache/spark/sql/csv/CsvSuite.scala   | 87 +++++++++++++++++++
 .../apache/spark/sql/csv/TestCsvData.scala    | 53 +++++++++++
 2 files changed, 140 insertions(+)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
new file mode 100644
index 000000000000..0288784f826e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.csv + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.TestSQLContext._ + +class CsvSuite extends QueryTest { + import TestCsvData._ + + test("Simple CSV with header") { + val csvSchemaRDD = csvRDD(diamondCSVWithHeader, header = true) + csvSchemaRDD.registerAsTable("diamonds") + + checkAnswer( + sql("select cut from diamonds limit 3"), + Seq(Seq("Ideal"), Seq("Premium"), Seq("Good"))) + + checkAnswer( + sql("select count(*) from diamonds"), + 3 + ) + } + + test("Simple CSV without header") { + val csvSchemaRDD = csvRDD(salesCSVWithoutHeader, delimiter = "; ") + csvSchemaRDD.registerAsTable("sales") + + checkAnswer( + sql("select distinct V0 from sales"), + "2003" + ) + } + + test("Quoted CSV with new lines") { + val csvSchemaRDD = csvRDD(carCSVWithQuotes, delimiter = ", ", header = true) + csvSchemaRDD.registerAsTable("cars") + + checkAnswer( + sql("select Model from cars limit 1"), + """Ford + |Pampa""".stripMargin + ) + + checkAnswer( + sql("select distinct Make from cars"), + "Ford" + ) + } + + test("Custom quoted CSV with inner quotes") { + val csvSchemaRDD = csvRDD(salesCSVWithDoubleQuotes, + delimiter = "; ", + quote = "|") + csvSchemaRDD.registerAsTable("quotedSales") + + checkAnswer( + sql("select distinct V0 from quotedSales"), + "2003" + ) + + checkAnswer( + sql("select distinct V2 from quotedSales where V2 like '%Mac%'"), + """Mac "Power" Adapter""" + ) + + checkAnswer( + sql("select distinct V2 from quotedSales where V2 like '%iPad%'"), + """iPad ||Power|| Adapter""" + ) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala new file mode 100644 index 000000000000..84565200e65c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.csv + +import org.apache.spark.sql.test.TestSQLContext + +object TestCsvData { + + val diamondCSVWithHeader = TestSQLContext.sparkContext.parallelize(Array( + "carat,cut,color,clarity,depth,table,price,x,y,z", + "0.23,Ideal,E,SI2,61.5,55,326,3.95,3.98,2.43", + "0.21,Premium,E,SI1,59.8,61,326,3.89,3.84,2.31", + "0.23,Good,E,VS1,56.9,65,327,4.05,4.07,2.31"), 1) + + val salesCSVWithoutHeader = TestSQLContext.sparkContext.parallelize(Array( + "2003; USA; MacBook; 11580.401776", + "2003; USA; Power Adapter; 4697.495022", + "2003; FRA; MacBook;5910; 787393", + "2003; FRA; Power Adapter; 2758.903949", + "2003; RUS; MacBook; 6992; 729325" + )) + + val carCSVWithQuotes = TestSQLContext.sparkContext.parallelize(Array( + """Year, Make, Model, Desc""", + """"1964", Ford, "Ford""", +"""Pampa", manufactured by Ford of Brazil""", + """"1947", "Ford", "Ford Pilot", "The Pilot was the first large""", +"""post-war British Ford"""", + """1997, Ford, Mustang, """ + ), 1) + + val salesCSVWithDoubleQuotes = TestSQLContext.sparkContext.parallelize(Array( + """|2003|; USA; |Mac "Power" Adapter|; 4697.495022""", + """|2003|; FRA; |iPhone "Power" Adapter|; 2758.903949""", + """|2003|; FRA; |iPad ||Power|| Adapter|; 2758.903949""" + ), 1) + +} \ No newline at end of file From 95a7a1ac49af6e87cc2b85f1f186d1a88ad47cc3 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 9 Jul 2014 19:24:15 -0700 Subject: [PATCH 05/20] Organized imports --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9335fcd55317..eb43f4dff4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -33,10 +33,10 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.csv._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ -import org.apache.spark.sql.csv._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.SparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala index 1d5cc43fb528..9af0c1801f0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.csv import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.Logging - /** * Tokenizer based on RFC 4180 for comma separated values. * It implements an iterator that returns each tokenized line as an Array[Any]. 
 */
-private[sql] class CsvTokenizer(inputIter: Iterator[String],
-    delimiter: String, quote: String) extends Iterator[Array[Any]] with Logging {
+private[sql] class CsvTokenizer(inputIter: Iterator[String],
+    delimiter: String, quote: String) extends Iterator[Array[Any]] {
 
   private val DELIM = delimiter.charAt(0)
   private val QUOTE = quote.charAt(0)
   private val BACKSLASH = '\\'
   private val NEWLINE = '\n'

From 65f7e95da29cb603af3c0919bebeb0d4e8ef1dcb Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 9 Jul 2014 19:32:35 -0700
Subject: [PATCH 06/20] Style cleanup

---
 .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index eb43f4dff4c2..03a831521d71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -139,9 +139,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
    *
    * @param path path to input file
    * @param delimiter Optional delimiter (default is comma)
-   * @param header Optional flag to indicate first line of each file is the header
-   *               (default is false)
    * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
    */
   def csvFile(path: String,
       delimiter: String = ",",
@@ -160,9 +160,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
    *
    * @param csv input RDD
    * @param delimiter Optional delimiter (default is comma)
-   * @param header Optional flag to indicate first line of each file is the header
-   *               (default is false)
    * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
    */
   def csvRDD(csv: RDD[String],
       delimiter: String = ",",

From b5eae3165c79d4cb6fa0ded1f70a5ff18213abe0 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 9 Jul 2014 23:08:00 -0700
Subject: [PATCH 07/20] Added Java API

---
 .../spark/sql/api/java/JavaSQLContext.scala | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 790d9ef22cf1..b9f8f99a4555 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.sql.csv.CsvRDD
 import org.apache.spark.sql.json.JsonRDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
@@ -122,6 +123,47 @@ class JavaSQLContext(val sqlContext: SQLContext) {
   def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD =
     new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0))
 
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
+   *
+   * NOTE: If there are new line characters inside quoted fields this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+ * + * @param path path to input file + * @param delimiter Optional delimiter (default is comma) + * @param quote Optional quote character or string (default is '"') + * @param header Optional flag to indicate first line of each file is the header + * (default is false) + */ + def csvFile(path: String, + delimiter: String = ",", + quote: String = "\"", + header: Boolean = false): JavaSchemaRDD = { + val csv = sqlContext.sparkContext.textFile(path) + csvRDD(csv, delimiter, quote, header) + } + + /** + * Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a + * [[JavaSchemaRDD]]. + * + * NOTE: If there are new line characters inside quoted fields, use + * [[JavaSparkContext#wholeTextFiles]] to read each file into a single partition. + * + * @param csv input RDD + * @param delimiter Optional delimiter (default is comma) + * @param quote Optional quote character of strig (default is '"') + * @param header Optional flag to indicate first line of each file is the hader + * (default is false) + */ + def csvRDD(csv: JavaRDD[String], + delimiter: String = ",", + quote: String = "\"", + header: Boolean = false): JavaSchemaRDD = { + new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, header)) + } + /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only * during the lifetime of this instance of SQLContext. From 44fe059138d0f532606aa3e57251c7179277a952 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 9 Jul 2014 23:10:59 -0700 Subject: [PATCH 08/20] Style --- sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala | 2 +- .../src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala index 57d0a7bb8035..a1b3d942c6eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala @@ -87,5 +87,5 @@ private[sql] object CsvRDD extends Logging { private def asAttributes(struct: StructType): Seq[AttributeReference] = { struct.fields.map(field => AttributeReference(field.name, field.dataType, nullable = true)()) } +} -} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala index 9af0c1801f0b..29d84a7fc9cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala @@ -114,4 +114,5 @@ private[sql] class CsvTokenizer(inputIter: Iterator[String], object QuoteState extends Enumeration { type State = Value val Quoted, Unquoted = Value -} \ No newline at end of file +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala index 0288784f826e..ce6b0ee91c28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala @@ -84,4 +84,5 @@ class CsvSuite extends QueryTest { """iPad ||Power|| Adapter""" ) } -} \ No newline at end of file +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala 
index 84565200e65c..832c166b88ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/TestCsvData.scala
@@ -50,4 +50,5 @@ object TestCsvData {
     """|2003|; FRA; |iPad ||Power|| Adapter|; 2758.903949"""
   ), 1)
 
-}
\ No newline at end of file
+}
+

From 70b6018efcebb97ef77171de07c86c0311e480f2 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Thu, 10 Jul 2014 14:28:09 -0700
Subject: [PATCH 09/20] Added python bindings

---
 python/pyspark/sql.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 5051c82da32a..dd355ea05229 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -184,6 +184,45 @@ def func(split, iterator):
         jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
         return SchemaRDD(jschema_rdd, self)
 
+    def csvFile(self, path, delimiter = ",", quote = "\"", header = False):
+        """Loads a CSV file (according to RFC 4180) and returns the result as a L{SchemaRDD}.
+        header flag specified if first line of each file should be treated as header.
+
+        NOTE: If there are new line characters inside quoted fields this method may fail to
+        parse correctly, because the two lines may be in different partitions. Use
+        L{SQLContext#csvRDD} to parse such files.
+
+        >>> csv = sqlCtx.csvFile('/tmp/csvDataFiles', delimiter = ", ", header = True)
+        >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
+        >>> csvRes = sqlCtx.sql("SELECT * FROM csvTable limit 10")
+        """
+        jschema_rdd = self._ssql_ctx.csvFile(path, delimiter, quote, header)
+        return SchemaRDD(jschema_rdd, self)
+
+    def csvRDD(self, rdd, delimiter = ",", quote = "\"", header = False):
+        """Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
+        L{SchemaRDD}.
+
+        NOTE: If there are new line characters inside quoted fields, use wholeTextFiles to
+        read each file into a single partition.
+
+        >>> rdd = sc.textFile("/tmp/csvDataFiles")
+        >>> csv = sqlCtx.csvRDD(rdd, delimiter = ", ", header = True)
+        >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
+        >>> csvRes = sqlCtx.sql("SELECT * FROM csvTable limit 10")
+        """
+        def func(split, iterator):
+            for x in iterator:
+                if not isinstance(x, basestring):
+                    x = unicode(x)
+                yield x.encode("utf-8")
+        keyed = PipelinedRDD(rdd, func)
+        keyed._bypass_serializer = True
+        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
+        jschema_rdd = self._ssql_ctx.csvRDD(jrdd.rdd(), delimiter, quote, header)
+        return SchemaRDD(jschema_rdd, self)
+
     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
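To make the API built up in patches 01 through 09 concrete for reviewers, here is a minimal usage sketch in Scala. It is not part of any patch in the series; the local master URL, the cars.csv path, the sample rows, and the column names are illustrative assumptions only, while the csvFile/csvRDD signatures and the registerAsTable call match what PATCH 03 and PATCH 04 introduce above.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object CsvUsageSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "csv-usage-sketch")
    val sqlContext = new SQLContext(sc)

    // Load an RFC 4180 CSV file, treating its first line as the header.
    // Column names come from the header; every column is typed as StringType.
    val cars = sqlContext.csvFile("cars.csv", delimiter = ",", quote = "\"", header = true)
    cars.registerAsTable("cars")
    sqlContext.sql("SELECT Make, Model FROM cars").collect().foreach(println)

    // Parse an existing RDD[String] without a header; columns are
    // auto-named V0, V1, ... by CsvRDD.inferSchema.
    val lines = sc.parallelize(Seq("2003; USA; MacBook", "2003; FRA; MacBook"))
    val sales = sqlContext.csvRDD(lines, delimiter = "; ")
    sales.registerAsTable("sales")
    sqlContext.sql("SELECT V0, V1 FROM sales").collect().foreach(println)

    sc.stop()
  }
}

Because csvFile reads the input with textFile, a quoted field that spans lines can straddle a partition boundary; as the docs added in PATCH 03 note, such data should be read with wholeTextFiles and handed to csvRDD instead.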
From 1409e44fdb03aac94ed997ec7a9e03ef4ead18c9 Mon Sep 17 00:00:00 2001 From: Hossein Date: Thu, 10 Jul 2014 14:29:28 -0700 Subject: [PATCH 10/20] Applied style comments --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 6 ++++-- .../org/apache/spark/sql/api/java/JavaSQLContext.scala | 6 ++++-- .../main/scala/org/apache/spark/sql/csv/CsvRDD.scala | 2 +- .../scala/org/apache/spark/sql/csv/CsvTokenizer.scala | 10 ++++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 03a831521d71..b9b1bc9851eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -143,7 +143,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param header Optional flag to indicate first line of each file is the header * (default is false) */ - def csvFile(path: String, + def csvFile( + path: String, delimiter: String = ",", quote: String = "\"", header: Boolean = false): SchemaRDD = { @@ -164,7 +165,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param header Optional flag to indicate first line of each file is the hader * (default is false) */ - def csvRDD(csv: RDD[String], + def csvRDD( + csv: RDD[String], delimiter: String = ",", quote: String = "\"", header: Boolean = false): SchemaRDD = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index b9f8f99a4555..a7b8da97a7f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -136,7 +136,8 @@ class JavaSQLContext(val sqlContext: SQLContext) { * @param header Optional flag to indicate first line of each file is the header * (default is false) */ - def csvFile(path: String, + def csvFile( + path: String, delimiter: String = ",", quote: String = "\"", header: Boolean = false): JavaSchemaRDD = { @@ -157,7 +158,8 @@ class JavaSQLContext(val sqlContext: SQLContext) { * @param header Optional flag to indicate first line of each file is the hader * (default is false) */ - def csvRDD(csv: JavaRDD[String], + def csvRDD( + csv: JavaRDD[String], delimiter: String = ",", quote: String = "\"", header: Boolean = false): JavaSchemaRDD = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala index a1b3d942c6eb..a470ea124390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala @@ -73,7 +73,7 @@ private[sql] object CsvRDD extends Logging { case _ => null } - private def parseCSV(iter: Iterator[Array[Any]], schema: StructType): Iterator[Row] = { + private def parseCSV(iter: Iterator[Array[String]], schema: StructType): Iterator[Row] = { val row = new GenericMutableRow(schema.fields.length) iter.map { tokens => schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala index 29d84a7fc9cf..11d22d44f7f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala @@ -23,8 +23,10 @@ import 
scala.collection.mutable.ArrayBuffer
  * Tokenizer based on RFC 4180 for comma separated values.
  * It implements an iterator that returns each tokenized line as an Array[Any].
  */
-private[sql] class CsvTokenizer(inputIter: Iterator[String],
-    delimiter: String, quote: String) extends Iterator[Array[Any]] {
+private[sql] class CsvTokenizer(
+    inputIter: Iterator[String],
+    delimiter: String,
+    quote: String) extends Iterator[Array[String]] {
 
   private val DELIM = delimiter.charAt(0)
   private val QUOTE = quote.charAt(0)
@@ -52,13 +54,13 @@ private[sql] class CsvTokenizer(inputIter: Iterator[String],
 
   def hasNext: Boolean = inputIter.hasNext
 
-  def next(): Array[Any] = {
+  def next(): Array[String] = {
     var curState = Unquoted
     var curPosition = 0
     var startPosition = 0
     var curChar: Char = '\0'
     var leftOver: String = "" // Used to keep track of tokens that span multiple lines
-    val tokens = new ArrayBuffer[Any]()
+    val tokens = new ArrayBuffer[String]()
     var line = inputIter.next() + '\n'
 
     while (curPosition < line.length) {

From 07b6a74f4e116e328125ec6298914e29d6b63d6b Mon Sep 17 00:00:00 2001
From: Hossein
Date: Sat, 12 Jul 2014 17:39:48 -0700
Subject: [PATCH 11/20] Avoid pattern matching and string mutation for
 efficiency

---
 .../apache/spark/sql/csv/CsvTokenizer.scala | 93 ++++++++++---------
 1 file changed, 50 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
index 11d22d44f7f5..23b906925d0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
@@ -26,28 +26,31 @@ private[sql] class CsvTokenizer(
     inputIter: Iterator[String],
     delimiter: String,
-    quote: String) extends Iterator[Array[String]] {
+    quote: Char) extends Iterator[Array[String]] {
 
-  private val DELIM = delimiter.charAt(0)
-  private val QUOTE = quote.charAt(0)
-  private val BACKSLASH = '\\'
-  private val NEWLINE = '\n'
+  private val DELIM = delimiter.charAt(0)
+  private val QUOTE = quote
+  private val DOUBLEQUOTE = quote.toString * 2
+  private val BACKSLASH = '\\'
+  private val NEWLINE = '\n'
 
   private val delimLength = delimiter.length
-  private val quoteLength = quote.length
 
   private def isDelimAt(inStr: String, index: Int): Boolean = {
     inStr.substring(index, index + delimLength) == delimiter
   }
 
-  private def isQuoteAt(inStr: String, index: Int): Boolean = {
-    inStr.substring(index, index + quoteLength) == quote
+  private def isDoubleQuoteAt(inStr: String, index: Int): Boolean = {
+    inStr.substring(index, index + 2) == DOUBLEQUOTE
   }
 
-  private def stripQuotes(inStr: String): String = if (inStr.startsWith(quote)) {
-    inStr.stripPrefix(quote).stripSuffix(quote)
-  } else {
-    inStr
+  private def stripQuotes(inStr: StringBuilder): String = {
+    val end = inStr.length - 1
+    if (inStr.charAt(0) == QUOTE && inStr.charAt(end) == QUOTE) {
+      inStr.deleteCharAt(end).deleteCharAt(0).toString()
+    } else {
+      inStr.toString()
+    }
   }
 
   import QuoteState._
@@ -59,53 +62,57 @@ private[sql] class CsvTokenizer(
     var curPosition = 0
     var startPosition = 0
     var curChar: Char = '\0'
-    var leftOver: String = "" // Used to keep track of tokens that span multiple lines
+    val leftOver = new StringBuilder() // Used to keep track of tokens that span multiple lines
    val tokens = new ArrayBuffer[String]()
    var line = inputIter.next() + '\n'

    while (curPosition < line.length) {
      curChar = line.charAt(curPosition)
- (curState, curChar) match { - case (Quoted, QUOTE) => - if (isQuoteAt(line, curPosition)) { - curState = Unquoted - curPosition += quoteLength + if (curChar == QUOTE) { + if (curState == Quoted) { + if (isDoubleQuoteAt(line, curPosition)) { + leftOver.append(line.substring(startPosition, curPosition + 1)) + curPosition += 2 + startPosition = curPosition } else { + curState = Unquoted curPosition += 1 } - case (Quoted, NEWLINE) if inputIter.hasNext => - leftOver = leftOver + line.substring(startPosition, curPosition + 1) + } else { + curState = Quoted + curPosition += 1 + } + } else if (curChar == DELIM) { + if (curState == Unquoted && isDelimAt(line, curPosition) && curPosition > startPosition) { + leftOver.append(line.substring(startPosition, curPosition)) + tokens.append(stripQuotes(leftOver)) + leftOver.clear() + curPosition += delimLength + startPosition = curPosition + } else { + curPosition += 1 + } + } else if (curChar == NEWLINE) { + if (curState == Quoted) { + leftOver.append(line.substring(startPosition, curPosition + 1)) line = inputIter.next() + '\n' curPosition = 0 startPosition = 0 - case (Unquoted, DELIM) => - if (isDelimAt(line, curPosition) && curPosition > startPosition) { - tokens.append(stripQuotes(leftOver + line.substring(startPosition, curPosition))) - curPosition += delimLength - startPosition = curPosition - leftOver = "" - } else { - curPosition += 1 - } - case (Unquoted, QUOTE) => - if (isQuoteAt(line, curPosition)) { - curState = Quoted - curPosition += quoteLength - } else { - curPosition += 1 - } - case (Unquoted, NEWLINE) => - if (startPosition == curPosition) { + } else { + if (curPosition == startPosition) { tokens.append(null) } else { - tokens.append(stripQuotes(leftOver + line.substring(startPosition, curPosition))) + leftOver.append(line.substring(startPosition, curPosition)) + tokens.append(stripQuotes(leftOver)) } + leftOver.clear() curPosition += 1 - case (_, BACKSLASH) => - curPosition += 2 - case (_, _) => - curPosition += 1 + } + } else if (curChar == BACKSLASH) { + curPosition += 2 + } else { + curPosition += 1 } } tokens.toArray From f7935c059846d09e635c3eca1c3040dba94e16c1 Mon Sep 17 00:00:00 2001 From: Hossein Date: Sat, 12 Jul 2014 17:40:55 -0700 Subject: [PATCH 12/20] Applying comments --- python/pyspark/sql.py | 45 ++++++++++++------- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/api/java/JavaSQLContext.scala | 4 +- .../org/apache/spark/sql/csv/CsvRDD.scala | 26 ++++++++--- 4 files changed, 54 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 5cf750997fb1..5c68d80c783f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -189,32 +189,41 @@ def func(split, iterator): return SchemaRDD(jschema_rdd, self) def csvFile(self, path, delimiter = ",", quote = "\"", header = False): - """Loads a CSV file (according to RFC 4180) and returns the result as a L{SchemaRDD}. - header flag specified if first line of each file should be treated as header. + """ + Loads a CSV file (according to RFC 4180) and returns the result as a L{SchemaRDD}. + header flag specified if first line of each file should be treated as header. - NOTE: If there are new line characters inside quoted fields this method may fail to - parse correctly, because the two lines may be in different partitions. Use - L{SQLContext#csvRDD} to parse such files. + NOTE: If there are new line characters inside quoted fields this method may fail to + parse correctly, because the two lines may be in different partitions. 
Use
+        L{SQLContext#csvRDD} to parse such files.
+
+        >>> import tempfile, shutil
+        >>> csvFile = tempfile.mkdtemp()
+        >>> shutil.rmtree(csvFile)
+        >>> ofn = open(csvFile, 'w')
+        >>> for csvStr in csvStrings:
+        ...   print>>ofn, csvStr
+        >>> ofn.close()
+        >>> csv = sqlCtx.csvFile(csvFile, delimiter = ", ", header = True)
         >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
-        >>> csvRes = sqlCtx.sql("SELECT * FROM csvTable limit 10")
+        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable limit 1")
         """
         jschema_rdd = self._ssql_ctx.csvFile(path, delimiter, quote, header)
         return SchemaRDD(jschema_rdd, self)
 
     def csvRDD(self, rdd, delimiter = ",", quote = "\"", header = False):
-        """Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
-        L{SchemaRDD}.
+        """
+        Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
+        L{SchemaRDD}.
 
-        NOTE: If there are new line characters inside quoted fields, use wholeTextFiles to
-        read each file into a single partition.
+        NOTE: If there are new line characters inside quoted fields, use wholeTextFiles to
+        read each file into a single partition.
 
-        >>> rdd = sc.textFile("/tmp/csvDataFiles")
-        >>> csv = sqlCtx.csvRDD(rdd, delimiter = ", ", header = True)
-        >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
-        >>> csvRes = sqlCtx.sql("SELECT * FROM csvTable limit 10")
+        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter = ", ", header = True)
+        >>> sqlCtx.registerRDDAsTable(csvrdd, "csvTable2")
+        >>> csvRes = sqlCtx.sql("SELECT count(*) FROM csvTable2")
         """
         def func(split, iterator):
             for x in iterator:
                 if not isinstance(x, basestring):
                     x = unicode(x)
                 yield x.encode("utf-8")
         keyed = PipelinedRDD(rdd, func)
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         jschema_rdd = self._ssql_ctx.csvRDD(jrdd.rdd(), delimiter, quote, header)
         return SchemaRDD(jschema_rdd, self)
 
     def sql(self, sqlQuery):
         """Return a L{SchemaRDD} representing the result of the given query.
@@ -539,6 +548,12 @@ def _test():
         '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
+    csvStrings = ['Year,Make,Model,Description',
+                  '"1997","Ford","E350",',
+                  '2000,Mercury,"Cougar", "Really ""Good"" car"',
+                  '2007,Honda,"Civic",']
+    globs['csvStrings'] = csvStrings
+    globs['csv'] = sc.parallelize(csvStrings)
     globs['nestedRdd1'] = sc.parallelize([
         {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
         {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 57dc93ab8a41..84395657e2e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -146,7 +146,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def csvFile(
       path: String,
       delimiter: String = ",",
-      quote: String = "\"",
+      quote: Char = '"',
       header: Boolean = false): SchemaRDD = {
     val csv = sparkContext.textFile(path)
     csvRDD(csv, delimiter, quote, header)
@@ -168,7 +168,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   def csvRDD(
       csv: RDD[String],
       delimiter: String = ",",
-      quote: String = "\"",
+      quote: Char = '"',
       header: Boolean = false): SchemaRDD = {
     new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, header))
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index a7b8da97a7f2..9f22fe3f155a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -139,7 +139,7 @@ class JavaSQLContext(val
sqlContext: SQLContext) { def csvFile( path: String, delimiter: String = ",", - quote: String = "\"", + quote: Char = '"', header: Boolean = false): JavaSchemaRDD = { val csv = sqlContext.sparkContext.textFile(path) csvRDD(csv, delimiter, quote, header) @@ -161,7 +161,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { def csvRDD( csv: JavaRDD[String], delimiter: String = ",", - quote: String = "\"", + quote: Char = '"', header: Boolean = false): JavaSchemaRDD = { new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, header)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala index a470ea124390..9e8823d44151 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala @@ -26,10 +26,18 @@ import org.apache.spark.sql.Logging private[sql] object CsvRDD extends Logging { + /** + * Infers schema of a CSV file. It uses the first row of the first partition to + * infer number of columns. If header flag is set, all lines that equal the first line + * are filtered before parsing. + * + * o If a line contains fewer tokens than the schema, it is padded with nulls + * o If a line has more tokens than the schema, extra tokens are ignored. + */ private[sql] def inferSchema( csv: RDD[String], delimiter: String, - quote: String, + quote: Char, useHeader: Boolean): LogicalPlan = { // Constructing schema @@ -37,7 +45,10 @@ private[sql] object CsvRDD extends Logging { val firstLine = csv.first() // Assuming first row is representative and using it to determine number of fields val firstRow = new CsvTokenizer(Seq(firstLine).iterator, delimiter, quote).next() + val numFields = firstRow.length + logger.info(s"Parsing CSV with $numFields.") val header = if (useHeader) { + logger.info(s"Using header line: $firstLine") firstRow } else { firstRow.zipWithIndex.map { case (value, index) => s"V$index" } @@ -46,7 +57,7 @@ private[sql] object CsvRDD extends Logging { val schemaFields = header.map { fieldName => StructField(fieldName.asInstanceOf[String], StringType, nullable = true) } - val schema = StructType(schemaFields) + val row = new GenericMutableRow(numFields) val parsedCSV = csv.mapPartitions { iter => // When using header, any input line that equals firstLine is assumed to be header @@ -56,9 +67,10 @@ private[sql] object CsvRDD extends Logging { iter } val tokenIter = new CsvTokenizer(csvIter, delimiter, quote) - parseCSV(tokenIter, schema) + parseCSV(tokenIter, schemaFields, row) } + val schema = StructType(schemaFields) SparkLogicalPlan(ExistingRdd(asAttributes(schema), parsedCSV)) } @@ -73,10 +85,12 @@ private[sql] object CsvRDD extends Logging { case _ => null } - private def parseCSV(iter: Iterator[Array[String]], schema: StructType): Iterator[Row] = { - val row = new GenericMutableRow(schema.fields.length) + private def parseCSV( + iter: Iterator[Array[String]], + schemaFields: Seq[StructField], + row: GenericMutableRow): Iterator[Row] = { iter.map { tokens => - schema.fields.zipWithIndex.foreach { + schemaFields.zipWithIndex.foreach { case (StructField(name, dataType, _), index) => row.update(index, castToType(tokens(index), dataType)) } From f3f0576c78470c5e368f87e7d7228d63d0d8ca5f Mon Sep 17 00:00:00 2001 From: Hossein Date: Tue, 15 Jul 2014 00:34:15 -0700 Subject: [PATCH 13/20] Adding unit tests for CsvTokenizer --- .../spark/sql/csv/CsvTokenizerSuite.scala | 104 ++++++++++++++++++ 1 file changed, 104 
insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala new file mode 100644 index 000000000000..b6c74fe30895 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvTokenizerSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.csv + +import org.scalatest.FunSuite + +class CsvTokenizerSuite extends FunSuite { + + def runTokenizer( + tokenLines: List[List[String]], + delimiter: String = ",", + quote: Char = '"'): Iterator[List[String]] = { + val linesIter = tokenLines.map(_.mkString(delimiter)) + new CsvTokenizer(linesIter.iterator, delimiter, quote).map(_.toList) + } + + test("Tokenize simple delimited fields") { + val tokens = List("V1", "V2", "V3", "V4", "V5") + val line = tokens.mkString("; ") + val tokenizer = new CsvTokenizer(Seq(line).iterator, delimiter = "; ", quote = '"') + + assert(tokenizer.next().toList === tokens) + } + + test("Tokenize many lines of delimited fields") { + val seed = List("AAA", "BBB", "CCC", "DDD") + val tokenLines = (1 to 10).toList.map { lineNumber => + seed.map(_ * lineNumber + s"-$lineNumber") + } + assert(runTokenizer(tokenLines).toList === tokenLines) + } + + test("Tokenize many lines of delimited fields containing quotes") { + val seed = List("V1", "V2", "V3", "V4", "V5") + val tokenLines = (1 to 10).toList.map { lineNumber => + seed.map(_ + s"-$lineNumber") + } + + val quotedTokenLines = tokenLines.map { tokens => + tokens.map("\"" + _ + "\"") + } + + assert(runTokenizer(quotedTokenLines).toList === tokenLines) + } + + test("Tokenize many lines of delimited fields containing double quotes inside quotes") { + val seed = List("V1", "V2", "V3", "V4", "V5") + val tokenLines = (1 to 10).toList.map { lineNumber => + seed.map(_ + s""" ""quoted stuff""-$lineNumber""") + } + + val parsedTokenLines = (1 to 10).toList.map { lineNumber => + seed.map(_ + s""" "quoted stuff"-$lineNumber""") + } + + val quotedTokenLines = tokenLines.map { tokens => + tokens.map("\"" + _ + "\"") + } + + assert(runTokenizer(quotedTokenLines).toList === parsedTokenLines) + } + + test("Tokenize delimited fields containing new lines inside quotes") { + + val headerLine = "Make; Model; Year; Note" + val firstLine = """Honda; Civic; 2006; "Reliable""" + val secondLine = """ car"""" + val thirdLine = """Toyota; Camry; 2006; "Best""" + val fourthLine = "selling" + val fifthLine = """car"""" + + val tokenLines = List(headerLine, firstLine, secondLine, thirdLine, fourthLine, fifthLine) + val tokenizer = new CsvTokenizer(tokenLines.iterator, delimiter = "; ", quote = 
'"')
 
    val firstRow = tokenizer.next()
    val secondRow = tokenizer.next()
    val thirdRow = tokenizer.next()
 
    assert(firstRow.toSeq === Seq("Make", "Model", "Year", "Note"))
    assert(secondRow(3) === """Reliable
      | car""".stripMargin)
    assert(thirdRow(3) === """Best
      |selling
      |car""".stripMargin)
  }
 
}

From 996c13ca59b24de234d5065fc6efe945c6ec1bc2 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Tue, 15 Jul 2014 00:35:18 -0700
Subject: [PATCH 14/20] Enable user to specify schema

---
 .../org/apache/spark/sql/SQLContext.scala     | 10 +++-
 .../spark/sql/api/java/JavaSQLContext.scala   | 34 ++++++-----
 .../org/apache/spark/sql/csv/CsvRDD.scala     | 56 ++++++++++---------
 3 files changed, 57 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 84395657e2e6..add2e819e20d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -140,6 +140,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @param path path to input file
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
+   * @param schema optional StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
@@ -147,9 +149,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
     path: String,
     delimiter: String = ",",
     quote: Char = '"',
+    schema: StructType = null,
     header: Boolean = false): SchemaRDD = {
     val csv = sparkContext.textFile(path)
-    csvRDD(csv, delimiter, quote, header)
+    csvRDD(csv, delimiter, quote, schema, header)
   }
 
   /**
@@ -162,6 +165,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @param csv input RDD
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
+   * @param schema optional StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
@@ -169,8 +174,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
     csv: RDD[String],
     delimiter: String = ",",
     quote: Char = '"',
+    schema: StructType = null,
     header: Boolean = false): SchemaRDD = {
-    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, header))
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, schema, header))
   }
 
   /**

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 9f22fe3f155a..82bef67c0a57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -124,25 +124,28 @@ class JavaSQLContext(val sqlContext: SQLContext) {
     new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0))
 
   /**
-   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
-   *
-   * NOTE: If there are new line characters inside quoted fields this method may fail to
-   * parse correctly, because the two lines may be in different partitions. Use
-   * [[SQLContext#csvRDD]] to parse such files.
-   *
-   * @param path path to input file
-   * @param delimiter Optional delimiter (default is comma)
-   * @param quote Optional quote character or string (default is '"')
-   * @param header Optional flag to indicate first line of each file is the header
-   *               (default is false)
-   */
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
+   *
+   * NOTE: If there are new line characters inside quoted fields this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
+   * @param schema optional StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   */
   def csvFile(
       path: String,
       delimiter: String = ",",
       quote: Char = '"',
+      schema: StructType = null,
       header: Boolean = false): JavaSchemaRDD = {
     val csv = sqlContext.sparkContext.textFile(path)
-    csvRDD(csv, delimiter, quote, header)
+    csvRDD(csv, delimiter, quote, schema, header)
   }
 
   /**
@@ -155,6 +158,8 @@ class JavaSQLContext(val sqlContext: SQLContext) {
    * @param csv input RDD
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
+   * @param schema optional StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
@@ -162,8 +167,9 @@ class JavaSQLContext(val sqlContext: SQLContext) {
     csv: JavaRDD[String],
     delimiter: String = ",",
     quote: Char = '"',
+    schema: StructType = null,
     header: Boolean = false): JavaSchemaRDD = {
-    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, header))
+    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, schema, header))
   }
 
   /**

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
index 9e8823d44151..6a53ea8bb341 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.csv
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
 import org.apache.spark.sql.Logging
@@ -38,26 +38,32 @@ private[sql] object CsvRDD extends Logging {
       csv: RDD[String],
       delimiter: String,
       quote: Char,
+      userSchema: StructType,
       useHeader: Boolean): LogicalPlan = {
 
     val firstLine = csv.first()
From 5ceb1e706c18fd082bce5530ed2149b1d52c3dfa Mon Sep 17 00:00:00 2001
From: Hossein
Date: Tue, 15 Jul 2014 00:37:25 -0700
Subject: [PATCH 15/20] Limiting number of new lines inside quotes

---
 .../scala/org/apache/spark/sql/csv/CsvTokenizer.scala | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
index 23b906925d0c..bb86aff5613e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvTokenizer.scala
@@ -34,6 +34,8 @@ private[sql] class CsvTokenizer(
   private val BACKSLASH = '\\'
   private val NEWLINE = '\n'
 
+  private val MAX_QUOTED_LINES = 10
+
   private val delimLength = delimiter.length
 
   private def isDelimAt(inStr: String, index: Int): Boolean = {
@@ -62,6 +64,7 @@ private[sql] class CsvTokenizer(
     var curPosition = 0
     var startPosition = 0
     var curChar: Char = '\0'
+    var quotedLines = 0
     val leftOver = new StringBuilder()  // Used to keep track of tokens that span multiple lines
     val tokens = new ArrayBuffer[String]()
     var line = inputIter.next() + '\n'
@@ -88,17 +91,19 @@ private[sql] class CsvTokenizer(
           leftOver.append(line.substring(startPosition, curPosition))
           tokens.append(stripQuotes(leftOver))
           leftOver.clear()
+          quotedLines = 0
           curPosition += delimLength
           startPosition = curPosition
         } else {
           curPosition += 1
         }
       } else if (curChar == NEWLINE) {
-        if (curState == Quoted) {
+        if (curState == Quoted && quotedLines < MAX_QUOTED_LINES) {
           leftOver.append(line.substring(startPosition, curPosition + 1))
           line = inputIter.next() + '\n'
           curPosition = 0
           startPosition = 0
+          quotedLines += 1
         } else {
           if (curPosition == startPosition) {
             tokens.append(null)
@@ -107,6 +112,7 @@ private[sql] class CsvTokenizer(
             tokens.append(stripQuotes(leftOver))
           }
           leftOver.clear()
+          quotedLines = 0
           curPosition += 1
         }
       } else if (curChar == BACKSLASH) {
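For illustration, the behavior this guard bounds: a field that opens a quote and does not close it on the same line makes the tokenizer pull in further input lines, joining them into a single token, for at most MAX_QUOTED_LINES continuations. A sketch of my reading of the code (quote character '|' and delimiter "; " are illustrative; this is not a test from the patch):

    // Two physical lines, one logical record: the quoted second field
    // spans the line break and is stitched back together with the '\n'.
    val lines = Seq("2014; |iPad", "Power Adapter|; 549").iterator
    val row = new CsvTokenizer(lines, "; ", '|').next()
    // expected: Array("2014", "iPad\nPower Adapter", "549")

Once the limit is exceeded, the newline is treated as a record terminator instead, which keeps a stray unbalanced quote from swallowing the rest of the partition.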
From 7d89e5eb001236596919401f2fe28bda3377bac8 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Tue, 15 Jul 2014 00:37:43 -0700
Subject: [PATCH 16/20] Updating tests

---
 .../src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
index ce6b0ee91c28..43359f34675a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/csv/CsvSuite.scala
@@ -64,9 +64,7 @@ class CsvSuite extends QueryTest {
   }
 
   test("Custom quoted CSV with inner quotes") {
-    val csvSchemaRDD = csvRDD(salesCSVWithDoubleQuotes,
-      delimiter = "; ",
-      quote = "|")
+    val csvSchemaRDD = csvRDD(salesCSVWithDoubleQuotes, delimiter = "; ", quote = '|')
     csvSchemaRDD.registerAsTable("quotedSales")
 
     checkAnswer(
@@ -81,7 +79,7 @@ class CsvSuite extends QueryTest {
 
     checkAnswer(
       sql("select distinct V2 from quotedSales where V2 like '%iPad%'"),
-      """iPad ||Power|| Adapter"""
+      """iPad |Power| Adapter"""
     )
   }
 }

From 143bfc149ce6ccdc48a5d3538bed3b142f12ddc6 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Tue, 15 Jul 2014 13:14:52 -0700
Subject: [PATCH 17/20] Fixed python test

---
 python/pyspark/sql.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 5c68d80c783f..7c5a31ed78a5 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -199,7 +199,7 @@ def csvFile(self, path, delimiter = ",", quote = "\"", header = False):
 
         >>> import tempfile, shutil
         >>> csvFile = tempfile.mkdtemp()
-        >>> shutil.rmtree(jsonFile)
+        >>> shutil.rmtree(csvFile)
         >>> ofn = open(csvFile, 'w')
         >>> for csvStr in csvStrings:
         ...     print>>ofn, csvStr
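The expectation change in the CsvSuite hunk above follows RFC 4180 quote escaping: inside a quoted field, the quote character is represented by doubling it, so ||Power|| in a |-quoted field decodes to |Power|. A sketch under the same illustrative assumptions as the earlier tokenizer example:

    // "||" inside a |-quoted field is one literal '|' after unescaping.
    val line = Seq("2014; |iPad ||Power|| Adapter|; 549").iterator
    val row = new CsvTokenizer(line, "; ", '|').next()
    // expected: row(1) == "iPad |Power| Adapter"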
From 6a2487b18f3450b1271744ee4479c0499e4a4709 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 23 Jul 2014 17:51:16 -0700
Subject: [PATCH 18/20] Using option for schema

---
 python/pyspark/sql.py | 18 +++++-----
 .../org/apache/spark/sql/csv/CsvRDD.scala | 34 +++++++++----------
 2 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 8bee2316b24a..d3754c11af4a 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -203,9 +203,10 @@ def csvFile(self, path, delimiter = ",", quote = "\"", header = False):
         >>> for csvStr in csvStrings:
         ...     print>>ofn, csvStr
         >>> ofn.close()
-        >>> csv = sqlCtx.csvFile(csvFile, delimiter = ", ", header = True)
+        >>> csv = sqlCtx.csvFile(csvFile, delimiter = ",", header = True)
         >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
-        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable limit 1")
+        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable where Make = 'Ford'")
+        >>> csvRes.collect() == [{"Year": "1997"}]
         True
         """
         jschema_rdd = self._ssql_ctx.csvFile(path, delimiter, quote, header)
@@ -219,9 +220,10 @@ def csvRDD(self, rdd, delimiter = ",", quote = "\"", header = False):
         NOTE: If there are new line characters inside quoted fields, use wholeTextFile
         to read each file into a single partition.
 
-        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter = ", ", header = True)
+        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter = ",", header = True)
         >>> sqlCtx.registerRDDAsTable(csvrdd, "csvTable2")
         >>> csvRes = sqlCtx.sql("SELECT count(*) FROM csvTable2")
+        >>> csvRes.collect() == [{"c0": 3}]
         True
         """
         def func(split, iterator):
@@ -553,11 +555,11 @@ def _test():
     ]
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
-    csvStrings = ['Year,Make,Model,Description'
-                  '"1997","Ford","E350",',
-                  '2000,Mercury,"Cougar", "Really ""Good"" car"',
-                  '2007,Honda,"Civic",']
-    globs['csvString'] = csvStrings
+    csvStrings = ['Year, Make, Model, Description',
+                  '"1997", "Ford", "E350", ',
+                  '2000, Mercury, "Cougar", "Really ""Good"" car"',
+                  '2007, Honda, "Civic", ']
+    globs['csvStrings'] = csvStrings
     globs['csv'] = sc.parallelize(csvStrings)
     globs['nestedRdd1'] = sc.parallelize([
         {"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
index 6a53ea8bb341..9180af380e5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/csv/CsvRDD.scala
@@ -38,26 +38,26 @@ private[sql] object CsvRDD extends Logging {
       csv: RDD[String],
       delimiter: String,
       quote: Char,
-      userSchema: StructType,
+      userSchema: Option[StructType],
       useHeader: Boolean): LogicalPlan = {
 
     val firstLine = csv.first()
-    val schema = if (userSchema == null) {
-      // Assume first row is representative and use it to determine number of fields
-      val firstRow = new CsvTokenizer(Seq(firstLine).iterator, delimiter, quote).next()
-      val header = if (useHeader) {
-        logger.info(s"Using header line: $firstLine")
-        firstRow
-      } else {
-        firstRow.zipWithIndex.map { case (value, index) => s"V$index" }
-      }
-      // By default fields are assumed to be StringType
-      val schemaFields = header.map { fieldName =>
-        StructField(fieldName, StringType, nullable = true)
-      }
-      StructType(schemaFields)
-    } else {
-      userSchema
+    val schema = userSchema match {
+      case Some(userSuppliedSchema) => userSuppliedSchema
+      case None =>
+        // Assume first row is representative and use it to determine number of fields
+        val firstRow = new CsvTokenizer(Seq(firstLine).iterator, delimiter, quote).next()
+        val header = if (useHeader) {
+          logger.info(s"Using header line: $firstLine")
+          firstRow
+        } else {
+          firstRow.zipWithIndex.map { case (value, index) => s"V$index" }
+        }
+        // By default fields are assumed to be StringType
+        val schemaFields = header.map { fieldName =>
+          StructField(fieldName, StringType, nullable = true)
+        }
+        StructType(schemaFields)
     }
 
     val numFields = schema.fields.length
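Switching the userSchema parameter to Option[StructType] moves the null check out of inferSchema: the public entry points now decide between Some and None, and the match above handles both paths explicitly. A sketch of the two resulting call shapes (argument values are illustrative; csv and userSchema are assumed to be in scope):

    // A caller that has a schema wraps it; inference is the None path.
    CsvRDD.inferSchema(csv, ",", '"', Some(userSchema), useHeader = true)
    CsvRDD.inferSchema(csv, ",", '"', None, useHeader = true)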
From 829a5afaad6b9d2de7b6c14a9d6cc187e2930002 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 23 Jul 2014 17:51:38 -0700
Subject: [PATCH 19/20] Overloaded methods

---
 .../org/apache/spark/sql/SQLContext.scala | 61 ++++++++++++++--
 .../spark/sql/api/java/JavaSQLContext.scala | 73 +++++++++++++++----
 2 files changed, 112 insertions(+), 22 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index add2e819e20d..6d55e263cdc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -138,10 +138,34 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * [[SQLContext#csvRDD]] to parse such files.
    *
    * @param path path to input file
+   * @param schema StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): SchemaRDD = {
+    val csv = sparkContext.textFile(path)
+    csvRDD(csv, schema, delimiter, quote, header)
+  }
+
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[SchemaRDD]].
+   * It infers the schema based on the first record.
+   *
+   * NOTE: If there are new line characters inside quoted fields this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
-   * @param schema optional StructType object to specify schema (field names and types). This will
-   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
@@ -149,12 +173,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
       path: String,
       delimiter: String = ",",
       quote: Char = '"',
-      schema: StructType = null,
       header: Boolean = false): SchemaRDD = {
     val csv = sparkContext.textFile(path)
-    csvRDD(csv, delimiter, quote, schema, header)
+    csvRDD(csv, delimiter, quote, header)
   }
 
+
   /**
    * Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
    * [[SchemaRDD]].
    *
    * NOTE: If there are new line characters inside quoted fields, use
    * [[SparkContext#wholeTextFiles]] to read each file into a single partition.
    *
    * @param csv input RDD
+   * @param schema StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: RDD[String],
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): SchemaRDD = {
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, Some(schema), header))
+  }
+
+  /**
+   * Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
+   * [[SchemaRDD]]. It infers the schema based on the first record.
+   *
+   * NOTE: If there are new line characters inside quoted fields, use
+   * [[SparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
-   * @param schema optional StructType object to specify schema (field names and types). This will
-   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
@@ -174,9 +220,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
       csv: RDD[String],
       delimiter: String = ",",
       quote: Char = '"',
-      schema: StructType = null,
       header: Boolean = false): SchemaRDD = {
-    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, schema, header))
+    new SchemaRDD(this, CsvRDD.inferSchema(csv, delimiter, quote, None, header))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 82bef67c0a57..999d1e11b6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -131,21 +131,44 @@ class JavaSQLContext(val sqlContext: SQLContext) {
    * [[SQLContext#csvRDD]] to parse such files.
    *
    * @param path path to input file
+   * @param schema StructType object to specify schema (field names and types). This will
+   *               override field names if header is used
    * @param delimiter Optional delimiter (default is comma)
    * @param quote Optional quote character or string (default is '"')
-   * @param schema optional StructType object to specify schema (field names and types). This will
-   *               override field names if header is used
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
   def csvFile(
       path: String,
-      delimiter: String = ",",
-      quote: Char = '"',
-      schema: StructType = null,
-      header: Boolean = false): JavaSchemaRDD = {
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
     val csv = sqlContext.sparkContext.textFile(path)
-    csvRDD(csv, delimiter, quote, schema, header)
+    csvRDD(csv, schema, delimiter, quote, header)
+  }
+
+  /**
+   * Loads a CSV file (according to RFC 4180) and returns the result as a [[JavaSchemaRDD]].
+   * It infers the schema based on the first record.
+   *
+   * NOTE: If there are new line characters inside quoted fields this method may fail to
+   * parse correctly, because the two lines may be in different partitions. Use
+   * [[SQLContext#csvRDD]] to parse such files.
+   *
+   * @param path path to input file
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   */
+  def csvFile(
+      path: String,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    val csv = sqlContext.sparkContext.textFile(path)
+    csvRDD(csv, delimiter, quote, header)
   }
 
   /**
@@ -156,22 +179,44 @@ class JavaSQLContext(val sqlContext: SQLContext) {
    * [[JavaSparkContext#wholeTextFiles]] to read each file into a single partition.
   *
    * @param csv input RDD
-   * @param delimiter Optional delimiter (default is comma)
-   * @param quote Optional quote character or string (default is '"')
    * @param schema optional StructType object to specify schema (field names and types). This will
    *               override field names if header is used
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
    * @param header Optional flag to indicate first line of each file is the header
    *               (default is false)
    */
   def csvRDD(
       csv: JavaRDD[String],
-      delimiter: String = ",",
-      quote: Char = '"',
-      schema: StructType = null,
-      header: Boolean = false): JavaSchemaRDD = {
-    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, schema, header))
+      schema: StructType,
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, Some(schema), header))
   }
 
+  /**
+   * Parses an RDD of String as a CSV (according to RFC 4180) and returns the result as a
+   * [[JavaSchemaRDD]]. It infers the schema based on the first record.
+   *
+   * NOTE: If there are new line characters inside quoted fields, use
+   * [[JavaSparkContext#wholeTextFiles]] to read each file into a single partition.
+   *
+   * @param csv input RDD
+   * @param delimiter Optional delimiter (default is comma)
+   * @param quote Optional quote character or string (default is '"')
+   * @param header Optional flag to indicate first line of each file is the header
+   *               (default is false)
+   */
+  def csvRDD(
+      csv: JavaRDD[String],
+      delimiter: String,
+      quote: Char,
+      header: Boolean): JavaSchemaRDD = {
+    new JavaSchemaRDD(sqlContext, CsvRDD.inferSchema(csv, delimiter, quote, None, header))
+  }
+
+
   /**
    * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
    * during the lifetime of this instance of SQLContext.
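The default arguments disappear from these signatures presumably because Scala default and named parameters cannot be used from Java; splitting each method into a with-schema and an inferring overload keeps both callable from either language. A sketch of the two call shapes against JavaSQLContext (the context, file name, and schema are illustrative assumptions):

    // Explicit schema: the StructType now comes right after the path,
    // and the remaining arguments are all required.
    val typed = javaSqlCtx.csvFile("cars.csv", carSchema, ",", '"', true)
    // Inference path: same call minus the schema argument.
    val inferred = javaSqlCtx.csvFile("cars.csv", ",", '"', true)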
From 11e6422ee25c8e61948b6d35375a6ccd1823dcc7 Mon Sep 17 00:00:00 2001
From: Hossein
Date: Wed, 23 Jul 2014 20:01:43 -0700
Subject: [PATCH 20/20] Fixed python test

---
 python/pyspark/sql.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d3754c11af4a..045774a17ac1 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -203,11 +203,11 @@ def csvFile(self, path, delimiter = ",", quote = "\"", header = False):
         >>> for csvStr in csvStrings:
         ...     print>>ofn, csvStr
         >>> ofn.close()
-        >>> csv = sqlCtx.csvFile(csvFile, delimiter = ",", header = True)
+        >>> csv = sqlCtx.csvFile(csvFile, delimiter = ", ", header = True)
         >>> sqlCtx.registerRDDAsTable(csv, "csvTable")
-        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable where Make = 'Ford'")
-        >>> csvRes.collect() == [{"Year": "1997"}]
-        True
+        >>> csvRes = sqlCtx.sql("SELECT Year FROM csvTable WHERE Make = 'Ford'")
+        >>> csvRes.collect()
+        [{u'Year': u'1997'}]
         """
         jschema_rdd = self._ssql_ctx.csvFile(path, delimiter, quote, header)
         return SchemaRDD(jschema_rdd, self)
@@ -220,7 +220,7 @@ def csvRDD(self, rdd, delimiter = ",", quote = "\"", header = False):
         NOTE: If there are new line characters inside quoted fields, use wholeTextFile
         to read each file into a single partition.
 
-        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter = ",", header = True)
+        >>> csvrdd = sqlCtx.csvRDD(csv, delimiter = ", ", header = True)
         >>> sqlCtx.registerRDDAsTable(csvrdd, "csvTable2")
         >>> csvRes = sqlCtx.sql("SELECT count(*) FROM csvTable2")
        >>> csvRes.collect() == [{"c0": 3}]