From 3b3c1b73fe8dda6190d10ac567d33aead5beb337 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 25 Jan 2016 22:50:54 -0800 Subject: [PATCH] A few minor tweaks to CSV reader. --- .../datasources/csv/CSVInferSchema.scala | 21 +++++++------------ .../datasources/csv/CSVRelation.scala | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 0aa4539e6051..ace8cd7ad864 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -30,16 +30,15 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.types._ -private[sql] object CSVInferSchema { +private[csv] object CSVInferSchema { /** * Similar to the JSON schema inference * 1. Infer type of each row * 2. Merge row types to find common type * 3. Replace any null types with string type - * TODO(hossein): Can we reuse JSON schema inference? [SPARK-12670] */ - def apply( + def infer( tokenRdd: RDD[Array[String]], header: Array[String], nullValue: String = ""): StructType = { @@ -65,10 +64,7 @@ private[sql] object CSVInferSchema { rowSoFar } - private[csv] def mergeRowTypes( - first: Array[DataType], - second: Array[DataType]): Array[DataType] = { - + def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { first.zipAll(second, NullType, NullType).map { case ((a, b)) => val tpe = findTightestCommonType(a, b).getOrElse(StringType) tpe match { @@ -82,8 +78,7 @@ private[sql] object CSVInferSchema { * Infer type of string field. Given known type Double, and a string "1", there is no * point checking if it is an Int, as the final type must be Double or higher. */ - private[csv] def inferField( - typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { + def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { if (field == null || field.isEmpty || field == nullValue) { typeSoFar } else { @@ -155,7 +150,8 @@ private[sql] object CSVInferSchema { } } -object CSVTypeCast { + +private[csv] object CSVTypeCast { /** * Casts given string datum to specified type. @@ -167,7 +163,7 @@ object CSVTypeCast { * @param datum string value * @param castType SparkSQL type */ - private[csv] def castTo( + def castTo( datum: String, castType: DataType, nullable: Boolean = true, @@ -201,10 +197,9 @@ object CSVTypeCast { * Helper method that converts string representation of a character to actual character. * It handles some Java escaped strings and throws exception if given string is longer than one * character. - * */ @throws[IllegalArgumentException] - private[csv] def toChar(str: String): Char = { + def toChar(str: String): Char = { if (str.charAt(0) == '\\') { str.charAt(1) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 5959f7cc5051..dc449fea956f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -139,7 +139,7 @@ private[csv] class CSVRelation( val parsedRdd = tokenRdd(header, paths) if (params.inferSchemaFlag) { - CSVInferSchema(parsedRdd, header, params.nullValue) + CSVInferSchema.infer(parsedRdd, header, params.nullValue) } else { // By default fields are assumed to be StringType val schemaFields = header.map { fieldName =>