
Commit ad71433

Handle more cases.
1 parent d774bfe commit ad71433

4 files changed (+72, -23 lines)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala

Lines changed: 33 additions & 4 deletions
@@ -61,7 +61,10 @@ private[json] object InferSchema {
             StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
         }
       }
-    }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType)
+    }.treeAggregate[DataType](
+      StructType(Seq()))(
+      compatibleRootType(columnNameOfCorruptRecords),
+      compatibleRootType(columnNameOfCorruptRecords))
 
     canonicalizeType(rootType) match {
       case Some(st: StructType) => st
@@ -170,12 +173,38 @@ private[json] object InferSchema {
     case other => Some(other)
   }
 
+  private def withCorruptField(
+      struct: StructType,
+      columnNameOfCorruptRecords: String): StructType = {
+    if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+      // If this given struct does not have a column used for corrupt records,
+      // add this field.
+      struct.add(columnNameOfCorruptRecords, StringType, nullable = true)
+    } else {
+      // Otherwise, just return this struct.
+      struct
+    }
+  }
+
   /**
    * Remove top-level ArrayType wrappers and merge the remaining schemas
    */
-  private def compatibleRootType: (DataType, DataType) => DataType = {
-    case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2)
-    case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2)
+  private def compatibleRootType(
+      columnNameOfCorruptRecords: String): (DataType, DataType) => DataType = {
+    // Since we support array of json objects at the top level,
+    // we need to check the element type and find the root level data type.
+    case (ArrayType(ty1, _), ty2) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2)
+    case (ty1, ArrayType(ty2, _)) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2)
+    // If we see any other data type at the root level, we get records that cannot be
+    // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+    case (struct: StructType, NullType) => struct
+    case (NullType, struct: StructType) => struct
+    case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    // If we get anything else, we call compatibleType.
+    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
     case (ty1, ty2) => compatibleType(ty1, ty2)
   }
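
To make the new merge rule concrete, here is a small self-contained sketch, not the Spark source, that mirrors the behaviour added above: when a StructType meets a non-struct, non-null root type (for example the LongType inferred from a bare 42 record), the corrupt-record column is appended. The object and method names below are illustrative, and StringType stands in for the full compatibleType fallback.

    import org.apache.spark.sql.types._

    // Minimal sketch of the root-type merge rule; not the Spark implementation.
    object RootTypeMergeSketch {
      private def withCorruptField(struct: StructType, name: String): StructType =
        if (struct.fieldNames.contains(name)) struct
        else struct.add(name, StringType, nullable = true)

      def mergeRoots(corruptColumn: String)(t1: DataType, t2: DataType): DataType =
        (t1, t2) match {
          case (struct: StructType, NullType) => struct
          case (NullType, struct: StructType) => struct
          case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
            withCorruptField(struct, corruptColumn)
          case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
            withCorruptField(struct, corruptColumn)
          case _ => StringType // placeholder for the full compatibleType(t1, t2) merge
        }

      def main(args: Array[String]): Unit = {
        val record = StructType(StructField("dummy", StringType) :: Nil)
        // Merging with the root type of a bare literal such as 42 (LongType)
        // yields a struct that also carries the corrupt-record column.
        println(mergeRoots("_corrupt_record")(record, LongType))
        // prints something like:
        // StructType(StructField(dummy,StringType,true), StructField(_corrupt_record,StringType,true))
      }
    }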

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala

Lines changed: 13 additions & 4 deletions
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
+private[json] class SparkSQLJsonProcessingException(msg: String) extends Exception(msg)
+
 object JacksonParser {
 
   def parse(
@@ -110,7 +112,7 @@ object JacksonParser {
         lowerCaseValue.equals("-inf")) {
         value.toFloat
       } else {
-        sys.error(s"Cannot parse $value as FloatType.")
+        throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.")
       }
 
     case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
@@ -127,7 +129,7 @@ object JacksonParser {
         lowerCaseValue.equals("-inf")) {
         value.toDouble
       } else {
-        sys.error(s"Cannot parse $value as DoubleType.")
+        throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.")
       }
 
     case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
@@ -174,7 +176,11 @@ object JacksonParser {
       convertField(factory, parser, udt.sqlType)
 
     case (token, dataType) =>
-      sys.error(s"Failed to parse a value for data type $dataType (current token: $token).")
+      // We cannot parse this token based on the given data type. So, we throw a
+      // SparkSQLJsonProcessingException and this exception will be caught by
+      // parseJson method.
+      throw new SparkSQLJsonProcessingException(
+        s"Failed to parse a value for data type $dataType (current token: $token).")
   }
 }
 
@@ -266,12 +272,15 @@ object JacksonParser {
             } else {
               array.toArray[InternalRow](schema)
             }
-          case _ => failedRecord(record)
+          case _ =>
+            failedRecord(record)
         }
       }
     } catch {
       case _: JsonProcessingException =>
         failedRecord(record)
+      case _: SparkSQLJsonProcessingException =>
+        failedRecord(record)
     }
   }
 }
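
The reason for replacing sys.error: it throws a bare RuntimeException, which the existing catch block (limited to Jackson's JsonProcessingException) would not intercept, so an unparseable value failed the task instead of landing in the corrupt-record column. Below is a rough standalone sketch of that catch-and-route shape, with purely illustrative names; it is not the Spark code itself.

    // Standalone illustration of the pattern: a dedicated exception type marks
    // "this record is malformed", and the caller converts it into a failed/corrupt
    // record instead of letting the whole job fail.
    class SparkSQLJsonProcessingException(msg: String) extends Exception(msg)

    object CorruptRecordRoutingSketch {
      // Stand-in for convertField: only integer records are considered well-formed.
      def convert(record: String): Long =
        try record.trim.toLong
        catch {
          case _: NumberFormatException =>
            throw new SparkSQLJsonProcessingException(s"Failed to parse a value: $record")
        }

      // Stand-in for parseJson's per-record handling: route failures to the left side.
      def parse(record: String): Either[String, Long] =
        try Right(convert(record))
        catch {
          case _: SparkSQLJsonProcessingException => Left(record)
        }

      def main(args: Array[String]): Unit = {
        println(parse("42"))            // Right(42)
        println(parse("not a number"))  // Left(not a number)
      }
    }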

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala

Lines changed: 24 additions & 14 deletions
@@ -1435,21 +1435,31 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       val schema = StructType(
         StructField("_unparsed", StringType, true) ::
         StructField("dummy", StringType, true) :: Nil)
-      val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords)
-      jsonDF.registerTempTable("jsonTable")
-
-      // In HiveContext, backticks should be used to access columns starting with a underscore.
-      checkAnswer(
-        sql(
-          """
-            |SELECT dummy, _unparsed
-            |FROM jsonTable
-          """.stripMargin),
-        Row("test", null) ::
-          Row(null, """42""") ::
-          Row(null, """ ","ian":"test"}""") :: Nil
-      )
 
+      {
+        // We need to make sure we can infer the schema.
+        val jsonDF = sqlContext.read.json(additionalCorruptRecords)
+        assert(jsonDF.schema === schema)
+      }
+
+      {
+        val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords)
+        jsonDF.registerTempTable("jsonTable")
+
+        // In HiveContext, backticks should be used to access columns starting with a underscore.
+        checkAnswer(
+          sql(
+            """
+              |SELECT dummy, _unparsed
+              |FROM jsonTable
+            """.stripMargin),
+          Row("test", null) ::
+            Row(null, """[1,2,3]""") ::
+            Row(null, """":"test", "a":1}""") ::
+            Row(null, """42""") ::
+            Row(null, """ ","ian":"test"}""") :: Nil
+        )
+      }
     }
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala

Lines changed: 2 additions & 1 deletion
@@ -191,6 +191,8 @@ private[json] trait TestJsonData {
   def additionalCorruptRecords: RDD[String] =
     sqlContext.sparkContext.parallelize(
       """{"dummy":"test"}""" ::
+      """[1,2,3]""" ::
+      """":"test", "a":1}""" ::
       """42""" ::
       """ ","ian":"test"}""" :: Nil)
 
@@ -203,7 +205,6 @@ private[json] trait TestJsonData {
       """{"b": [{"c": {}}]}""" ::
       """]""" :: Nil)
 
-
   lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
 
   def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
