Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._


private[csv] object CSVInferSchema {

/**
Expand All @@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)

val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}

StructType(structFields)
Expand All @@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
}

/**
 * Merges two arrays of per-column types into one, position by position.
 *
 * The shorter array is padded with NullType so that arrays of different lengths can be
 * combined. Each pair is resolved through findTightestCommonType; when that returns None
 * the merged type for the position is NullType.
 *
 * NullType is deliberately not widened to StringType at this stage: merging
 * Array(NullType) with Array(NullType) must yield Array(NullType), so a position that is
 * still undetermined can later adopt the type of a non-null value.
 *
 * @param first  one array of column types
 * @param second the other array of column types
 * @return the merged type for every position present in either input
 */
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
  first.zipAll(second, NullType, NullType).map { case (a, b) =>
    findTightestCommonType(a, b).getOrElse(NullType)
  }
}

Expand Down Expand Up @@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
case (StringType, t2) => Some(StringType)
case (t1, StringType) => Some(StringType)

// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
Expand All @@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
}
}


private[csv] object CSVTypeCast {

/**
Expand Down
5 changes: 5 additions & 0 deletions sql/core/src/test/resources/simple_sparse.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
A,B,C,D
1,,,
,1,,
,,1,
,,,1
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}

// Merging two all-null rows must keep the column as NullType rather than
// widening it to another type during the merge.
test("Merging Nulltypes should yield Nulltype.") {
  val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
  // Compare as sequences: arrays use reference equality under ==.
  assert(mergedNullTypes.toSeq == Seq(NullType))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val simpleSparseFile = "simple_sparse.csv"

private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
Expand Down Expand Up @@ -224,7 +225,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema.fieldNames.size === 1)
}


test("DDL test with empty file") {
sqlContext.sql(s"""
|CREATE TEMPORARY TABLE carsTable
Expand Down Expand Up @@ -387,4 +387,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(carsCopy, withHeader = true)
}
}

// The sparse fixture has four columns (A-D), each populated with the value 1 in exactly
// one row and empty elsewhere; every column is expected to be inferred as IntegerType.
test("Schema inference correctly identifies the datatype when data is sparse.") {
  val sparseReader = sqlContext.read
    .format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
  val sparseFrame = sparseReader.load(testFile(simpleSparseFile))

  val inferredTypes = sparseFrame.schema.fields.map(_.dataType).toSeq
  val expectedTypes = Seq(IntegerType, IntegerType, IntegerType, IntegerType)

  assert(inferredTypes == expectedTypes)
}
}