diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 2e75f0c8a1827..aa0132b068c94 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -494,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None):
         self._jwrite.orc(path)
 
     @since(1.4)
-    def jdbc(self, url, table, mode=None, properties=None):
+    def jdbc(self, url, table, mode=None, properties=None, columnMapping=None):
         """Saves the content of the :class:`DataFrame` to a external database table via JDBC.
 
         .. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -511,13 +511,20 @@ def jdbc(self, url, table, mode=None, properties=None):
         :param properties: JDBC database connection arguments, a list of
                            arbitrary string tag/value. Normally at least a
                            "user" and "password" property should be included.
+        :param columnMapping: optional column name mapping from DataFrame field names
+                              to JDBC table column names.
         """
         if properties is None:
             properties = dict()
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
-        self._jwrite.mode(mode).jdbc(url, table, jprop)
+        if columnMapping is None:
+            columnMapping = dict()
+        jcolumnMapping = JavaClass("java.util.HashMap", self._sqlContext._sc._gateway._gateway_client)()
+        for k in columnMapping:
+            jcolumnMapping.put(k, columnMapping[k])
+        self._jwrite.mode(mode).jdbc(url, table, jprop, jcolumnMapping)
 
 
 def _test():
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 03867beb78224..764dbfd56167f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -265,10 +265,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
    *                             tag/value. Normally at least a "user" and "password" property
    *                             should be included.
+   * @param columnMapping Maps DataFrame column names to target table column names.
+   *                      This parameter can be omitted if the target table has been or will be
+   *                      created by this method, in which case the target table structure
+   *                      matches the DataFrame structure.
+   *                      This parameter is strongly recommended if the target table already
+   *                      exists and has been created outside of this method.
+   *                      If omitted, the SQL insert statement will not include column names,
+   *                      which means that the field ordering of the DataFrame must match
+   *                      the target table column ordering.
    *
    * @since 1.4.0
    */
-  def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: scala.collection.Map[String, String]): Unit = {
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
@@ -303,7 +315,33 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       conn.close()
     }
 
-    JdbcUtils.saveTable(df, url, table, props)
+    JdbcUtils.saveTable(df, url, table, props, columnMapping)
+  }
+
+  /**
+   * (Java-friendly) version of
+   * `DataFrameWriter.jdbc(String, String, Properties, scala.collection.Map[String, String])`.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties,
+      columnMapping: java.util.Map[String, String]): Unit = {
+    // Convert the Java Map into a Scala Map
+    var sColumnMapping: scala.collection.Map[String, String] = null
+    if (columnMapping != null) {
+      sColumnMapping = columnMapping.asScala
+    }
+    jdbc(url, table, connectionProperties, sColumnMapping)
+  }
+
+  /**
+   * Three-parameter version of
+   * `DataFrameWriter.jdbc(String, String, Properties, scala.collection.Map[String, String])`.
+   */
+  def jdbc(url: String,
+      table: String,
+      connectionProperties: Properties): Unit = {
+    jdbc(url, table, connectionProperties, null.asInstanceOf[Map[String, String]])
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 252f1cfd5d9c5..1166989926370 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,16 +61,16 @@ object JdbcUtils extends Logging {
 
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
+   * If a columnMapping is provided, it will be used to translate DataFrame
+   * column names into table column names.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
-    val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
-    var fieldsLeft = rddSchema.fields.length
-    while (fieldsLeft > 0) {
-      sql.append("?")
-      if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
-      fieldsLeft = fieldsLeft - 1
-    }
-    conn.prepareStatement(sql.toString())
+  def insertStatement(conn: Connection,
+      dialect: JdbcDialect,
+      table: String,
+      rddSchema: StructType,
+      columnMapping: scala.collection.Map[String, String]): PreparedStatement = {
+    val sql = dialect.getInsertStatement(table, rddSchema, columnMapping)
+    conn.prepareStatement(sql)
   }
 
   /**
@@ -122,6 +122,7 @@ object JdbcUtils extends Logging {
       iterator: Iterator[Row],
       rddSchema: StructType,
       nullTypes: Array[Int],
+      columnMapping: scala.collection.Map[String, String] = null,
       batchSize: Int,
       dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
@@ -139,7 +140,7 @@ object JdbcUtils extends Logging {
       if (supportsTransactions) {
         conn.setAutoCommit(false) // Everything in the same db transaction.
       }
-      val stmt = insertStatement(conn, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema, columnMapping)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
@@ -234,7 +235,8 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      properties: Properties = new Properties()) {
+      properties: Properties = new Properties(),
+      columnMapping: scala.collection.Map[String, String] = null) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType
@@ -245,7 +247,8 @@ object JdbcUtils extends Logging {
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes,
+        columnMapping, batchSize, dialect)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 13db141f27db6..64c09ade0617d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -108,6 +108,30 @@ abstract class JdbcDialect extends Serializable {
   def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
   }
 
+  /**
+   * Get the SQL statement that should be used to insert new records into the table.
+   * Dialects can override this method to return a statement that works best in a particular
+   * database.
+   * @param table The name of the table.
+   * @param rddSchema The schema of the DataFrame to be inserted.
+   * @param columnMapping An optional mapping from DataFrame field names to database column
+   *                      names.
+   * @return The SQL statement to use for inserting into the table.
+   */
+  def getInsertStatement(table: String,
+      rddSchema: StructType,
+      columnMapping: scala.collection.Map[String, String] = null): String = {
+    if (columnMapping == null) {
+      rddSchema.fields.map(_ => "?")
+        .mkString(s"INSERT INTO $table VALUES (", ", ", ")")
+    } else {
+      rddSchema.fields.map(
+        field => columnMapping.getOrElse(field.name, field.name)
+      ).mkString(s"INSERT INTO $table (", ", ", ") ") +
+        rddSchema.fields.map(_ => "?").mkString("VALUES (", ", ", ")")
+    }
+  }
+
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133b..b263ccc9ea07d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -92,6 +92,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
 
     df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
     assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+    assert(
+      2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).collect()(0).length)
+  }
+
+  test("Basic CREATE with columnMapping") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val columnMapping = Map("name" -> "name", "id" -> "id")
+    df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties, columnMapping)
+    assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count)
     assert(
       2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
   }
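
Usage sketch (not part of the patch): the snippet below illustrates how the new columnMapping argument would be called from Scala once this diff is applied. The JDBC URL, the table name PEOPLE, the target column names PERSON_ID/PERSON_NAME, and the credentials are hypothetical placeholders, not values taken from the patch.

// Sketch only, assuming the patch above is applied.
import java.util.Properties

import org.apache.spark.sql.SQLContext

object ColumnMappingExample {
  def write(sqlContext: SQLContext): Unit = {
    val props = new Properties()
    props.setProperty("user", "testUser")        // placeholder credentials
    props.setProperty("password", "testPass")

    // DataFrame with field names "id" and "name".
    val df = sqlContext.createDataFrame(Seq((1, "alice"), (2, "bob"))).toDF("id", "name")

    // Map DataFrame field names to the columns of a pre-existing table whose
    // names differ. With this mapping, the default JdbcDialect.getInsertStatement
    // introduced above generates
    //   INSERT INTO PEOPLE (PERSON_ID, PERSON_NAME) VALUES (?, ?)
    // instead of the positional INSERT INTO PEOPLE VALUES (?, ?).
    val columnMapping = Map("id" -> "PERSON_ID", "name" -> "PERSON_NAME")

    // Calls the new four-argument overload added in DataFrameWriter.
    df.write.jdbc("jdbc:h2:mem:testdb", "PEOPLE", props, columnMapping)
  }
}

Note that getInsertStatement falls back to the DataFrame field name for any field missing from the mapping (columnMapping.getOrElse(field.name, field.name)), so a partial mapping only renames the columns it mentions.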