
Commit c400848

maropu authored and gatorsmile committed
[SPARK-20009][SQL] Support DDL strings for defining schema in functions.from_json
## What changes were proposed in this pull request?

This PR adds `StructType.fromDDL`, which converts a DDL-format string into a `StructType`, so that schemas for `functions.from_json` can be defined as DDL strings.

## How was this patch tested?

Added tests in `JsonFunctionsSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #17406 from maropu/SPARK-20009.
1 parent 142f6d1 commit c400848
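For orientation, here is a minimal sketch of what the change enables end to end. The DataFrame contents, column names, and object name below are illustrative, not taken from the patch; the point is that the string-schema overload of `from_json` now accepts a DDL string such as `a INT, b STRING` in addition to a JSON-serialized schema.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json

object FromJsonDdlDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("from_json-ddl").getOrCreate()
    import spark.implicits._

    val df = Seq("""{"a": 1, "b": "hello"}""").toDF("value")

    // The schema is given as a DDL-format string; before this change, the
    // string overload only accepted a JSON-serialized schema.
    val parsed = df.select(
      from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]()))
    parsed.show(truncate = false) // expected: one struct column containing [1, hello]

    spark.stop()
  }
}
```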

File tree: 5 files changed, +90 −25 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala

Lines changed: 6 additions & 0 deletions
@@ -417,6 +417,12 @@ object StructType extends AbstractDataType {
     }
   }
 
+  /**
+   * Creates StructType for a given DDL-formatted string, which is a comma separated list of field
+   * definitions, e.g., a INT, b STRING.
+   */
+  def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl)
+
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
   def apply(fields: java.util.List[StructField]): StructType = {
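A quick sketch of the new helper in isolation, mirroring the javadoc above (the field names are just examples):

```scala
import org.apache.spark.sql.types._

// Parse a comma-separated list of field definitions into a StructType.
val schema: StructType = StructType.fromDDL("a INT, b STRING")

// Equivalent schema built programmatically; sameType ignores nullability.
val expected = new StructType().add("a", IntegerType).add("b", StringType)
assert(schema.sameType(expected))
```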

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 64 additions & 21 deletions
@@ -169,30 +169,72 @@ class DataTypeSuite extends SparkFunSuite {
     assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
   }
 
-  def checkDataTypeJsonRepr(dataType: DataType): Unit = {
-    test(s"JSON - $dataType") {
+  def checkDataTypeFromJson(dataType: DataType): Unit = {
+    test(s"from Json - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)
     }
   }
 
-  checkDataTypeJsonRepr(NullType)
-  checkDataTypeJsonRepr(BooleanType)
-  checkDataTypeJsonRepr(ByteType)
-  checkDataTypeJsonRepr(ShortType)
-  checkDataTypeJsonRepr(IntegerType)
-  checkDataTypeJsonRepr(LongType)
-  checkDataTypeJsonRepr(FloatType)
-  checkDataTypeJsonRepr(DoubleType)
-  checkDataTypeJsonRepr(DecimalType(10, 5))
-  checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT)
-  checkDataTypeJsonRepr(DateType)
-  checkDataTypeJsonRepr(TimestampType)
-  checkDataTypeJsonRepr(StringType)
-  checkDataTypeJsonRepr(BinaryType)
-  checkDataTypeJsonRepr(ArrayType(DoubleType, true))
-  checkDataTypeJsonRepr(ArrayType(StringType, false))
-  checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
-  checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+  def checkDataTypeFromDDL(dataType: DataType): Unit = {
+    test(s"from DDL - $dataType") {
+      val parsed = StructType.fromDDL(s"a ${dataType.sql}")
+      val expected = new StructType().add("a", dataType)
+      assert(parsed.sameType(expected))
+    }
+  }
+
+  checkDataTypeFromJson(NullType)
+
+  checkDataTypeFromJson(BooleanType)
+  checkDataTypeFromDDL(BooleanType)
+
+  checkDataTypeFromJson(ByteType)
+  checkDataTypeFromDDL(ByteType)
+
+  checkDataTypeFromJson(ShortType)
+  checkDataTypeFromDDL(ShortType)
+
+  checkDataTypeFromJson(IntegerType)
+  checkDataTypeFromDDL(IntegerType)
+
+  checkDataTypeFromJson(LongType)
+  checkDataTypeFromDDL(LongType)
+
+  checkDataTypeFromJson(FloatType)
+  checkDataTypeFromDDL(FloatType)
+
+  checkDataTypeFromJson(DoubleType)
+  checkDataTypeFromDDL(DoubleType)
+
+  checkDataTypeFromJson(DecimalType(10, 5))
+  checkDataTypeFromDDL(DecimalType(10, 5))
+
+  checkDataTypeFromJson(DecimalType.SYSTEM_DEFAULT)
+  checkDataTypeFromDDL(DecimalType.SYSTEM_DEFAULT)
+
+  checkDataTypeFromJson(DateType)
+  checkDataTypeFromDDL(DateType)
+
+  checkDataTypeFromJson(TimestampType)
+  checkDataTypeFromDDL(TimestampType)
+
+  checkDataTypeFromJson(StringType)
+  checkDataTypeFromDDL(StringType)
+
+  checkDataTypeFromJson(BinaryType)
+  checkDataTypeFromDDL(BinaryType)
+
+  checkDataTypeFromJson(ArrayType(DoubleType, true))
+  checkDataTypeFromDDL(ArrayType(DoubleType, true))
+
+  checkDataTypeFromJson(ArrayType(StringType, false))
+  checkDataTypeFromDDL(ArrayType(StringType, false))
+
+  checkDataTypeFromJson(MapType(IntegerType, StringType, true))
+  checkDataTypeFromDDL(MapType(IntegerType, StringType, true))
+
+  checkDataTypeFromJson(MapType(IntegerType, ArrayType(DoubleType), false))
+  checkDataTypeFromDDL(MapType(IntegerType, ArrayType(DoubleType), false))
 
   val metadata = new MetadataBuilder()
     .putString("name", "age")
@@ -201,7 +243,8 @@ class DataTypeSuite extends SparkFunSuite {
     StructField("a", IntegerType, nullable = true),
     StructField("b", ArrayType(DoubleType), nullable = false),
     StructField("c", DoubleType, nullable = false, metadata)))
-  checkDataTypeJsonRepr(structType)
+  checkDataTypeFromJson(structType)
+  checkDataTypeFromDDL(structType)
 
   def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
     test(s"Check the default size of $dataType") {

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 12 additions & 3 deletions
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -3055,13 +3056,21 @@ object functions {
    * with the specified schema. Returns `null`, in the case of an unparseable string.
    *
    * @param e a string column containing JSON data.
-   * @param schema the schema to use when parsing the json string as a json string
+   * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1,
+   *               the user-provided schema has to be in JSON format. Since Spark 2.2, the DDL
+   *               format is also supported for the schema.
    *
    * @group collection_funcs
    * @since 2.1.0
    */
-  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
-    from_json(e, DataType.fromJson(schema), options)
+  def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
+    val dataType = try {
+      DataType.fromJson(schema)
+    } catch {
+      case NonFatal(_) => StructType.fromDDL(schema)
+    }
+    from_json(e, dataType, options)
+  }
 
   /**
   * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
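With this fallback, the same overload accepts both the Spark 2.1-era JSON-serialized schema and a DDL string. A hedged sketch for `spark-shell`, where `spark.implicits._` is already in scope and the data is illustrative:

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._

val df = Seq("""{"a": 1, "b": "x"}""").toDF("value")
val options = new java.util.HashMap[String, String]()

// Spark 2.1 style: a JSON-serialized schema still goes through DataType.fromJson.
val jsonSchema = new StructType().add("a", IntegerType).add("b", StringType).json
df.select(from_json($"value", jsonSchema, options)).show()

// Spark 2.2 style: if JSON parsing fails, the string is retried as DDL.
df.select(from_json($"value", "a INT, b STRING", options)).show()
```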

sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala

Lines changed: 7 additions & 0 deletions
@@ -156,6 +156,13 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
   }
 
+  test("from_json uses DDL strings for defining a schema") {
+    val df = Seq("""{"a": 1, "b": "haa"}""").toDS()
+    checkAnswer(
+      df.select(from_json($"value", "a INT, b STRING", new java.util.HashMap[String, String]())),
+      Row(Row(1, "haa")) :: Nil)
+  }
+
   test("to_json - struct") {
     val df = Seq(Tuple1(Tuple1(1))).toDF("a")

sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
-import org.apache.spark.sql.{sources, Row, SparkSession}
+import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection

0 commit comments
