From 66c9e80085261940a35d6354d7256da591e34af1 Mon Sep 17 00:00:00 2001
From: sureshthalamati
Date: Fri, 2 Dec 2016 15:22:17 -0800
Subject: [PATCH] [SPARK-10849][SQL] Add new JDBC data source metadata
 property to allow users to specify the database column type when creating a
 table on write.

---
 .../datasources/jdbc/JdbcUtils.scala    |  6 ++-
 .../spark/sql/jdbc/JDBCWriteSuite.scala | 40 ++++++++++++++++++-
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index b13849475811..2e0ccc88d38a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -685,7 +685,11 @@ object JdbcUtils extends Logging {
     val dialect = JdbcDialects.get(url)
     schema.fields foreach { field =>
       val name = dialect.quoteIdentifier(field.name)
-      val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
+      val typ: String = if (field.metadata.contains("createTableColumnType")) {
+        field.metadata.getString("createTableColumnType")
+      } else {
+        getJdbcType(field.dataType, dialect).databaseTypeDefinition
+      }
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 354af29d4237..69d43bf22168 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -25,7 +25,8 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -349,4 +350,41 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(e.contains("Invalid value `0` for parameter `numPartitions` in table writing " +
       "via JDBC. The minimum value is 1."))
   }
+
+  test("SPARK-10849: create table using user specified column type.") {
+    val data = Seq[Row](
+      Row(1, "dave", "Boston", "electric cars"),
+      Row(2, "mary", "boston", "building planes")
+    )
+    val nvarcharMd =
+      new MetadataBuilder().putString("createTableColumnType", "NVARCHAR(123)").build()
+    // Use the H2 VARCHAR_IGNORECASE type instead of TEXT to perform case-insensitive comparisons.
+    val varcharIgnoreMd =
+      new MetadataBuilder().putString("createTableColumnType", "VARCHAR_IGNORECASE(20)").build()
+    val schema = StructType(
+      StructField("id", IntegerType) ::
+      StructField("name", StringType, metadata = nvarcharMd) ::
+      StructField("city", StringType, metadata = varcharIgnoreMd) ::
+      StructField("descr", StringType) ::
+      Nil)
+    val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+    assert(JdbcUtils.schemaString(df.schema, url1) ==
+      s""""id" INTEGER , "name" NVARCHAR(123) , "city" VARCHAR_IGNORECASE(20) , "descr" TEXT """)
+
+    // Create the table with the user-specified column types, and verify the data.
+    df.write.jdbc(url1, "TEST.DBCOLTYPETEST", properties)
+    assert(spark.read.jdbc(url1,
+      """(select * from test.DBCOLTYPETEST where "city"='Boston')""", properties).count == 2)
+  }
+
+  test("SPARK-10849: createTableColumnType property with invalid data type") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val invalidMd =
+      new MetadataBuilder().putString("createTableColumnType", "INVALID(123)").build()
+    val modifiedDf = df.withColumn("name", col("name"), invalidMd)
+    val msg = intercept[org.h2.jdbc.JdbcSQLException] {
+      modifiedDf.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.USERDBTYPETEST", properties)
+    }.getMessage()
+    assert(msg.contains("Unknown data type: \"INVALID\""))
+  }
 }
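
A minimal usage sketch of the new property, assuming a live `spark` session
(e.g. spark-shell with the H2 driver on the classpath); the H2 in-memory URL
and the TEST.PEOPLE table name below are illustrative, not part of the patch.
A caller opts a column into a database-specific type by putting the
createTableColumnType key into that field's metadata; columns without the key
keep the dialect's default mapping, exactly as JdbcUtils.schemaString does
above.

    import java.util.Properties

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // Ask for a specific database type for "city"; JdbcUtils.schemaString
    // reads this metadata key when it builds the CREATE TABLE column list.
    val cityMd =
      new MetadataBuilder().putString("createTableColumnType", "VARCHAR(64)").build()
    val schema = StructType(
      StructField("id", IntegerType) ::
      StructField("city", StringType, metadata = cityMd) ::
      Nil)
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row(1, "Boston"), Row(2, "Seattle"))), schema)
    // "id" carries no metadata key, so it falls back to the dialect default;
    // "city" is created as VARCHAR(64).
    df.write.jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", new Properties())

Because the fallback stays in the else branch, the property is purely
additive: existing writers that set no field metadata see no behavior change.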