diff --git a/build.sbt b/build.sbt
index 1f511bb..d4f2fc0 100755
--- a/build.sbt
+++ b/build.sbt
@@ -60,3 +60,7 @@ sparkComponents += "sql"
 libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test"
 
 libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test"
+
+// Fork to help tests of methods using reflection.
+// See https://issues.apache.org/jira/browse/SPARK-5281.
+fork in Test := true
diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index 3d41de0..4cb6abe 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -19,6 +19,7 @@
 import org.apache.commons.csv.CSVFormat
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.spark.sql.{SQLContext, DataFrame, Row}
+import org.apache.spark.sql.types.StructType
 
 package object csv {
 
@@ -31,14 +32,16 @@ package object csv {
         delimiter: Char = ',',
         quote: Char = '"',
         escape: Char = '\\',
-        mode: String = "PERMISSIVE") = {
+        mode: String = "PERMISSIVE",
+        schema: Option[StructType] = None) = {
       val csvRelation = CsvRelation(
         location = filePath,
         useHeader = useHeader,
         delimiter = delimiter,
         quote = quote,
         escape = escape,
-        parseMode = mode)(sqlContext)
+        parseMode = mode,
+        userSchema = schema.orNull)(sqlContext)
       sqlContext.baseRelationToDataFrame(csvRelation)
     }
 
@@ -53,7 +56,7 @@ package object csv {
       sqlContext.baseRelationToDataFrame(csvRelation)
     }
   }
- 
+
   implicit class CsvSchemaRDD(dataFrame: DataFrame) {
 
     /**
diff --git a/src/main/scala/com/databricks/spark/csv/rdd/package.scala b/src/main/scala/com/databricks/spark/csv/rdd/package.scala
new file mode 100644
index 0000000..b4f4498
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/csv/rdd/package.scala
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.csv
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.StructType
+
+package object rdd {
+  implicit class CsvContextRDD(sqlContext: SQLContext) {
+    def csvFileToRDD[T: scala.reflect.runtime.universe.TypeTag : scala.reflect.ClassTag](
+        filePath: String,
+        useHeader: Boolean = true,
+        delimiter: Char = ',',
+        quote: Char = '"',
+        escape: Char = '\\',
+        mode: String = "DROPMALFORMED"): RDD[T] = {
+
+      if (mode.toUpperCase == util.ParseModes.PERMISSIVE_MODE) {
+        throw new IllegalArgumentException("PERMISSIVE mode is invalid for this method")
+      }
+
+      val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
+      if (schema.exists { structField => !structField.dataType.isPrimitive }) {
+        throw new IllegalArgumentException("type must be a case class with only primitive fields")
+      }
+
+      val csvContext = new CsvContext(sqlContext)
+      val df = csvContext.csvFile(filePath, useHeader, delimiter, quote, escape, mode, Some(schema))
+      df.mapPartitions[T] { iter =>
+        val rowConverter = RowConverter[T]()
+        iter.map { row => rowConverter.convert(row) }
+      }
+    }
+
+    def tsvFileToRDD[T: scala.reflect.runtime.universe.TypeTag : scala.reflect.ClassTag](
+        filePath: String,
+        useHeader: Boolean = true,
+        mode: String = "DROPMALFORMED"): RDD[T] = {
+      csvFileToRDD[T](filePath, useHeader, delimiter = '\t', quote = '"', escape = '\\', mode)
+    }
+  }
+
+  case class RowConverter[T]()(implicit ct: scala.reflect.ClassTag[T]) {
+    // http://docs.scala-lang.org/overviews/reflection/environment-universes-mirrors.html
+
+    // For Scala 2.10, because we're initializing the runtime universe, this is not thread-safe.
+    // http://docs.scala-lang.org/overviews/reflection/thread-safety.html
+    val ru = scala.reflect.runtime.universe
+
+    val mirror = ru.runtimeMirror(getClass.getClassLoader)
+    val classSymbol = mirror.classSymbol(ct.runtimeClass)
+    val classMirror = mirror.reflectClass(classSymbol)
+    val constructorSymbol = classSymbol.toType.declaration(ru.nme.CONSTRUCTOR).asMethod
+    val constructorMirror = classMirror.reflectConstructor(constructorSymbol)
+
+    def convert(row: Row): T = {
+      val args = row.toSeq
+      require(constructorSymbol.paramss.head.size == args.size)
+      constructorMirror.apply(args: _*).asInstanceOf[T]
+    }
+  }
+}
diff --git a/src/test/resources/cars-with-typed-columns-without-headers.csv b/src/test/resources/cars-with-typed-columns-without-headers.csv
new file mode 100644
index 0000000..20230ea
--- /dev/null
+++ b/src/test/resources/cars-with-typed-columns-without-headers.csv
@@ -0,0 +1,4 @@
+"2012","Tesla","S","No comment",1,350000.00
+
+1997,Ford,E350,"Go get one now they are going fast",3,25000.00
+2015,Chevy,Volt
\ No newline at end of file
diff --git a/src/test/resources/cars-with-typed-columns.csv b/src/test/resources/cars-with-typed-columns.csv
new file mode 100644
index 0000000..6099267
--- /dev/null
+++ b/src/test/resources/cars-with-typed-columns.csv
@@ -0,0 +1,5 @@
+year,make,model,comment,stocked,price
+"2012","Tesla","S","No comment",1,350000.00
+
+1997,Ford,E350,"Go get one now they are going fast",3,25000.00
+2015,Chevy,Volt
\ No newline at end of file
diff --git a/src/test/resources/cars-with-typed-columns.tsv b/src/test/resources/cars-with-typed-columns.tsv
new file mode 100644
index 0000000..df824a8
--- /dev/null
+++ b/src/test/resources/cars-with-typed-columns.tsv
@@ -0,0 +1,5 @@
+year	make	model	comment	stocked	price
+"2012"	"Tesla"	"S"	"No comment"	1	350000.00
"Tesla" "S" "No comment" 1 350000.00 + +1997 Ford E350 "Go get one now they are going fast" 3 25000.00 +2015 Chevy Volt \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/csv/rdd/CsvToRDDSuite.scala b/src/test/scala/com/databricks/spark/csv/rdd/CsvToRDDSuite.scala new file mode 100644 index 0000000..736e3bc --- /dev/null +++ b/src/test/scala/com/databricks/spark/csv/rdd/CsvToRDDSuite.scala @@ -0,0 +1,65 @@ +package com.databricks.spark.csv.rdd + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.test._ +import org.apache.spark.sql.types._ +import org.scalatest.FunSuite +import org.scalatest.Matchers + +// Because this suite tests reflection, the test only works in SBT if the config uses forking. +// There is no workaround for Eclipse. +// See https://issues.apache.org/jira/browse/SPARK-5281. +class CsvToRDDSuite extends FunSuite with Matchers { + import CsvToRDDSuite._ + import TestSQLContext._ + + val carsFile = "src/test/resources/cars-with-typed-columns.csv" + val carsFileTsv = "src/test/resources/cars-with-typed-columns.tsv" + val carsFileWithoutHeaders = "src/test/resources/cars-with-typed-columns-without-headers.csv" + + test("DSL for RDD with DROPMALFORMED parsing mode") { + val rdd = TestSQLContext.csvFileToRDD[Car](carsFile) + rdd.collect() should contain theSameElementsAs Seq( + Car(2012, "Tesla", "S", "No comment", 1, 350000.00), + Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00)) + } + + test("DSL for RDD with DROPMALFORMED parsing mode, TSV") { + val rdd = TestSQLContext.tsvFileToRDD[Car](carsFileTsv) + rdd.collect() should contain theSameElementsAs Seq( + Car(2012, "Tesla", "S", "No comment", 1, 350000.00), + Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00)) + } + + test("DSL for RDD with DROPMALFORMED parsing mode, without headers") { + val rdd = TestSQLContext.csvFileToRDD[Car](carsFileWithoutHeaders, useHeader = false) + rdd.collect() should contain theSameElementsAs Seq( + Car(2012, "Tesla", "S", "No comment", 1, 350000.00), + Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00)) + } + + test("DSL for RDD with FAILFAST parsing mode") { + intercept[org.apache.spark.SparkException] { + val rdd = TestSQLContext.csvFileToRDD[Car](carsFile, mode = "FAILFAST") + println(rdd.collect()) + } + } + + test("DSL for RDD with PERMISSIVE parsing mode") { + intercept[IllegalArgumentException] { + TestSQLContext.csvFileToRDD[Car](carsFile, mode = "PERMISSIVE") + } + } + + test("DSL for RDD with invalid type argument") { + intercept[IllegalArgumentException] { + TestSQLContext.csvFileToRDD[CarWithNonPrimitive](carsFile) + } + } +} + +object CsvToRDDSuite { + case class Car(year: Int, make: String, model: String, comment: String, stocked: Int, price: Double) + case class CarWithNonPrimitive(year: Int, makeAndModel: MakeAndModel, comment: String, stocked: Int, price: Double) + case class MakeAndModel(make: String, model: String) +}