4 changes: 4 additions & 0 deletions build.sbt
@@ -60,3 +60,7 @@ sparkComponents += "sql"
libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test"

libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test"

// Fork to help tests of methods using reflection.
// See https://issues.apache.org/jira/browse/SPARK-5281.
fork in Test := true
9 changes: 6 additions & 3 deletions src/main/scala/com/databricks/spark/csv/package.scala
@@ -19,6 +19,7 @@ import org.apache.commons.csv.CSVFormat
import org.apache.hadoop.io.compress.CompressionCodec

import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType

package object csv {

@@ -31,14 +32,16 @@ package object csv {
delimiter: Char = ',',
quote: Char = '"',
escape: Char = '\\',
mode: String = "PERMISSIVE") = {
mode: String = "PERMISSIVE",
schema: Option[StructType] = None) = {
val csvRelation = CsvRelation(
location = filePath,
useHeader = useHeader,
delimiter = delimiter,
quote = quote,
escape = escape,
parseMode = mode)(sqlContext)
parseMode = mode,
userSchema = schema.orNull)(sqlContext)
sqlContext.baseRelationToDataFrame(csvRelation)
}

@@ -53,7 +56,7 @@
sqlContext.baseRelationToDataFrame(csvRelation)
}
}

implicit class CsvSchemaRDD(dataFrame: DataFrame) {

/**
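The new optional schema parameter lets callers supply a StructType up front instead of relying on inference. A minimal usage sketch, assuming an existing SparkContext named sc; the file path and column definitions are illustrative:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}
import com.databricks.spark.csv._

val sqlContext = new SQLContext(sc)
// Passing Some(schema) hands the StructType straight to CsvRelation via
// userSchema; the default None keeps the old behavior of inferring types.
val carSchema = StructType(Seq(
  StructField("year", IntegerType),
  StructField("make", StringType),
  StructField("price", DoubleType)))
val df = sqlContext.csvFile("cars.csv", useHeader = true, schema = Some(carSchema))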
77 changes: 77 additions & 0 deletions src/main/scala/com/databricks/spark/csv/rdd/package.scala
@@ -0,0 +1,77 @@
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.csv

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

package object rdd {
implicit class CsvContextRDD(sqlContext: SQLContext) {
def csvFileToRDD[T: scala.reflect.runtime.universe.TypeTag : scala.reflect.ClassTag](
filePath: String,
useHeader: Boolean = true,
delimiter: Char = ',',
quote: Char = '"',
escape: Char = '\\',
mode: String = "DROPMALFORMED"): RDD[T] = {

if (mode == util.ParseModes.PERMISSIVE_MODE) {
throw new IllegalArgumentException("PERMISSIVE mode is invalid for this method")
}

val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
if (schema.exists { structField => !structField.dataType.isPrimitive }) {
throw new IllegalArgumentException("type must be a case class with only primitive fields")
}

val csvContext = new CsvContext(sqlContext)
val df = csvContext.csvFile(filePath, useHeader, delimiter, quote, escape, mode, Some(schema))
df.mapPartitions[T] { iter =>
val rowConverter = RowConverter[T]()
iter.map { row => rowConverter.convert(row) }
}
}

def tsvFileToRDD[T: scala.reflect.runtime.universe.TypeTag : scala.reflect.ClassTag](
filePath: String,
useHeader: Boolean = true,
mode: String = "DROPMALFORMED"): RDD[T] = {
csvFileToRDD[T](filePath, useHeader, delimiter = '\t', quote = '"', escape = '\\', mode = mode)
}
}

case class RowConverter[T]()(implicit ct: scala.reflect.ClassTag[T]) {
// http://docs.scala-lang.org/overviews/reflection/environment-universes-mirrors.html

// In Scala 2.10, runtime reflection is not thread-safe, so initializing the
// runtime universe here must not happen concurrently from multiple threads.
// http://docs.scala-lang.org/overviews/reflection/thread-safety.html
val ru = scala.reflect.runtime.universe

val mirror = ru.runtimeMirror(getClass.getClassLoader)
val classSymbol = mirror.classSymbol(ct.runtimeClass)
val classMirror = mirror.reflectClass(classSymbol)
val constructorSymbol = classSymbol.toType.declaration(ru.nme.CONSTRUCTOR).asMethod
val constructorMirror = classMirror.reflectConstructor(constructorSymbol)

def convert(row: Row): T = {
val args = row.toSeq
require(constructorSymbol.paramss.head.size == args.size, s"row has ${args.size} fields but the constructor takes ${constructorSymbol.paramss.head.size}")
constructorMirror.apply(args: _*).asInstanceOf[T]
}
}
}
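A minimal usage sketch of the new API, assuming an existing SQLContext named sqlContext; Car mirrors the case class in the test suite below and must be a top-level (or object-level) definition, not method-local, so that runtime reflection can instantiate it:

import org.apache.spark.rdd.RDD
import com.databricks.spark.csv.rdd._

case class Car(year: Int, make: String, model: String,
    comment: String, stocked: Int, price: Double)

// Malformed rows are dropped under the default DROPMALFORMED mode;
// mode = "FAILFAST" raises an error on the first bad record instead.
val cars: RDD[Car] = sqlContext.csvFileToRDD[Car]("cars.csv")
val tsvCars: RDD[Car] = sqlContext.tsvFileToRDD[Car]("cars.tsv")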
4 changes: 4 additions & 0 deletions src/test/resources/cars-with-typed-columns-without-headers.csv
@@ -0,0 +1,4 @@
"2012","Tesla","S","No comment",1,350000.00

1997,Ford,E350,"Go get one now they are going fast",3,25000.00
2015,Chevy,Volt
5 changes: 5 additions & 0 deletions src/test/resources/cars-with-typed-columns.csv
@@ -0,0 +1,5 @@
year,make,model,comment,stocked,price
"2012","Tesla","S","No comment",1,350000.00

1997,Ford,E350,"Go get one now they are going fast",3,25000.00
2015,Chevy,Volt
5 changes: 5 additions & 0 deletions src/test/resources/cars-with-typed-columns.tsv
@@ -0,0 +1,5 @@
year make model comment stocked price
"2012" "Tesla" "S" "No comment" 1 350000.00

1997 Ford E350 "Go get one now they are going fast" 3 25000.00
2015 Chevy Volt
65 changes: 65 additions & 0 deletions src/test/scala/com/databricks/spark/csv/rdd/CsvToRDDSuite.scala
@@ -0,0 +1,65 @@
package com.databricks.spark.csv.rdd

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.Matchers

// Because this suite exercises runtime reflection, its tests only pass under SBT
// when forking is enabled (fork in Test := true in build.sbt).
// There is no workaround for Eclipse.
// See https://issues.apache.org/jira/browse/SPARK-5281.
class CsvToRDDSuite extends FunSuite with Matchers {
import CsvToRDDSuite._
import TestSQLContext._

val carsFile = "src/test/resources/cars-with-typed-columns.csv"
val carsFileTsv = "src/test/resources/cars-with-typed-columns.tsv"
val carsFileWithoutHeaders = "src/test/resources/cars-with-typed-columns-without-headers.csv"

test("DSL for RDD with DROPMALFORMED parsing mode") {
val rdd = TestSQLContext.csvFileToRDD[Car](carsFile)
rdd.collect() should contain theSameElementsAs Seq(
Car(2012, "Tesla", "S", "No comment", 1, 350000.00),
Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00))
}

test("DSL for RDD with DROPMALFORMED parsing mode, TSV") {
val rdd = TestSQLContext.tsvFileToRDD[Car](carsFileTsv)
rdd.collect() should contain theSameElementsAs Seq(
Car(2012, "Tesla", "S", "No comment", 1, 350000.00),
Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00))
}

test("DSL for RDD with DROPMALFORMED parsing mode, without headers") {
val rdd = TestSQLContext.csvFileToRDD[Car](carsFileWithoutHeaders, useHeader = false)
rdd.collect() should contain theSameElementsAs Seq(
Car(2012, "Tesla", "S", "No comment", 1, 350000.00),
Car(1997, "Ford", "E350", "Go get one now they are going fast", 3, 25000.00))
}

test("DSL for RDD with FAILFAST parsing mode") {
intercept[org.apache.spark.SparkException] {
val rdd = TestSQLContext.csvFileToRDD[Car](carsFile, mode = "FAILFAST")
println(rdd.collect())
}
}

test("DSL for RDD with PERMISSIVE parsing mode") {
intercept[IllegalArgumentException] {
TestSQLContext.csvFileToRDD[Car](carsFile, mode = "PERMISSIVE")
}
}

test("DSL for RDD with invalid type argument") {
intercept[IllegalArgumentException] {
TestSQLContext.csvFileToRDD[CarWithNonPrimitive](carsFile)
}
}
}

object CsvToRDDSuite {
case class Car(year: Int, make: String, model: String, comment: String, stocked: Int, price: Double)
case class CarWithNonPrimitive(year: Int, makeAndModel: MakeAndModel, comment: String, stocked: Int, price: Double)
case class MakeAndModel(make: String, model: String)
}