diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 065c8572b06a2..f745f00f4a881 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -96,8 +96,9 @@ object JdbcUtils extends Logging {
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
-    val columns = rddSchema.fields.map(_.name).mkString(",")
+  def insertStatement(dialect: JdbcDialect, conn: Connection, table: String, rddSchema: StructType)
+    : PreparedStatement = {
+    val columns = rddSchema.fields.map(f => quoteColumnName(dialect, f.name)).mkString(",")
     val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
     val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
     conn.prepareStatement(sql)
@@ -169,7 +170,7 @@ object JdbcUtils extends Logging {
     if (supportsTransactions) {
       conn.setAutoCommit(false) // Everything in the same db transaction.
     }
-    val stmt = insertStatement(conn, table, rddSchema)
+    val stmt = insertStatement(dialect, conn, table, rddSchema)
     try {
       var rowCount = 0
       while (iterator.hasNext) {
@@ -245,6 +246,16 @@ object JdbcUtils extends Logging {
       Array[Byte]().iterator
   }
 
+  /**
+   * Quotes a column name using the identifier-quoting rules of the given JDBC dialect.
+   * @param dialect the JDBC dialect
+   * @param columnName the input column name
+   * @return the quoted column name
+   */
+  private def quoteColumnName(dialect: JdbcDialect, columnName: String): String = {
+    dialect.quoteIdentifier(columnName)
+  }
+
   /**
    * Compute the schema string for this RDD.
    */
@@ -252,7 +263,8 @@ object JdbcUtils extends Logging {
     val sb = new StringBuilder()
     val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field =>
-      val name = field.name
+
+      val name = quoteColumnName(dialect, field.name)
       val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133b..4147fb415f2a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -52,7 +52,11 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     conn1.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate()
     conn1.prepareStatement("drop table if exists test.people1").executeUpdate()
     conn1.prepareStatement(
-      "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+      "create table test.people1 (name TEXT(32) NOT NULL, `the id` INTEGER NOT NULL)")
+      .executeUpdate()
+    conn1.prepareStatement(
+      "create table test.orders (`order` TEXT(32) NOT NULL, `order id` INTEGER NOT NULL)")
+      .executeUpdate()
     conn1.commit()
 
     sql(
@@ -68,6 +72,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
         |USING org.apache.spark.sql.jdbc
         |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
+
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE ORDERS
+        |USING org.apache.spark.sql.jdbc
+        |OPTIONS (url '$url1', dbtable 'TEST.ORDERS', user 'testUser', password 'testPass')
+      """.stripMargin.replaceAll("\n", " "))
   }
 
   after {
@@ -151,4 +162,13 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("SPARK-14460: Insert into table with column containing space") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    df.write.insertInto("PEOPLE1")
+    assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+
+    df.write.insertInto("ORDERS")
+    assert(2 === sqlContext.read.jdbc(url1, "TEST.ORDERS", properties).count)
+  }
 }
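
For reviewers, a minimal standalone sketch (not part of the patch) of what routing column names through the dialect changes in the generated SQL. It assumes Spark's public JdbcDialects.get and JdbcDialect.quoteIdentifier behave as in the current codebase (backticks for the MySQL dialect, ANSI double quotes for the default dialect); the object name and JDBC URLs below are illustrative only.

// Illustrative sketch: how dialect-based quoting affects the INSERT built by insertStatement.
// Needs spark-sql on the classpath; no SparkContext is required.
import org.apache.spark.sql.jdbc.JdbcDialects

object QuoteIdentifierSketch {
  def main(args: Array[String]): Unit = {
    val fields = Seq("order", "order id")

    // MySQL dialect quotes identifiers with backticks.
    val mysql = JdbcDialects.get("jdbc:mysql://localhost/test")
    val columns = fields.map(mysql.quoteIdentifier).mkString(",")
    val placeholders = fields.map(_ => "?").mkString(",")
    // Prints: INSERT INTO test.orders (`order`,`order id`) VALUES (?,?)
    println(s"INSERT INTO test.orders ($columns) VALUES ($placeholders)")

    // The default dialect (used for the H2 URLs in JDBCWriteSuite) falls back to
    // ANSI double quotes, e.g. "order id".
    val h2 = JdbcDialects.get("jdbc:h2:mem:testdb0")
    println(h2.quoteIdentifier("order id"))
  }
}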