11 changes: 9 additions & 2 deletions python/pyspark/sql/readwriter.py
@@ -494,7 +494,7 @@ def orc(self, path, mode=None, partitionBy=None):
self._jwrite.orc(path)

@since(1.4)
def jdbc(self, url, table, mode=None, properties=None):
def jdbc(self, url, table, mode=None, properties=None, columnMapping=None):
"""Saves the content of the :class:`DataFrame` to a external database table via JDBC.

.. note:: Don't create too many partitions in parallel on a large cluster;\
@@ -511,13 +511,20 @@ def jdbc(self, url, table, mode=None, properties=None):
:param properties: JDBC database connection arguments, a list of
arbitrary string tag/value. Normally at least a
"user" and "password" property should be included.
:param columnMapping: optional mapping from DataFrame field names to
                      JDBC table column names.
"""
if properties is None:
properties = dict()
jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
for k in properties:
jprop.setProperty(k, properties[k])
self._jwrite.mode(mode).jdbc(url, table, jprop)
if columnMapping is None:
columnMapping = dict()
jcolumnMapping = JavaClass("java.util.HashMap", self._sqlContext._sc._gateway._gateway_client)()
for k in columnMapping:
jcolumnMapping.put(k, columnMapping[k])
self._jwrite.mode(mode).jdbc(url, table, jprop, jcolumnMapping)


def _test():
42 changes: 40 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -253,6 +253,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}

/**
* (Scala-specific)
* Saves the content of the [[DataFrame]] to an external database table via JDBC. In the case the
* table already exists in the external database, behavior of this function depends on the
* save mode, specified by the `mode` function (default to throwing an exception).
@@ -265,10 +266,22 @@
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
*                             tag/value. Normally at least a "user" and "password" property
*                             should be included.
* @param columnMapping Maps DataFrame column names to target table column names.
*                      This parameter can be omitted if the target table is (or will be)
*                      created by this method, because the target table structure then
*                      matches the DataFrame structure.
*                      This parameter is strongly recommended if the target table already
*                      exists and was created outside of this method.
*                      If omitted, the SQL insert statement will not include column names,
*                      which means that the field ordering of the DataFrame must match
*                      the target table column ordering.
*
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
def jdbc(url: String,
table: String,
connectionProperties: Properties,
columnMapping: scala.collection.immutable.Map[String, String]): Unit = {
val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
@@ -303,7 +316,32 @@
conn.close()
}

JdbcUtils.saveTable(df, url, table, props)
JdbcUtils.saveTable(df, url, table, props, columnMapping)
}

/**
* (Java-specific) version of the jdbc method.
*/
def jdbc(url: String,
table: String,
connectionProperties: Properties,
columnMapping: java.util.Map[String, String]): Unit = {
// Convert the Java Map into an immutable Scala Map
var sColumnMapping: scala.collection.immutable.Map[String, String] = null
if (columnMapping != null) {
sColumnMapping = collection.immutable.Map(columnMapping.asScala.toList: _*)
}
jdbc(url, table, connectionProperties, sColumnMapping)
}

/**
* Legacy three-parameter version of the jdbc method.
*/
def jdbc(url: String,
table: String,
connectionProperties: Properties): Unit = {
val columnMapping: scala.collection.immutable.Map[String, String] = null
jdbc(url, table, connectionProperties, columnMapping)
}

/**
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -61,16 +61,16 @@ object JdbcUtils extends Logging {

/**
* Returns a PreparedStatement that inserts a row into table via conn.
* If a columnMapping is provided, it will be used to translate rdd
* column names into table column names.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
var fieldsLeft = rddSchema.fields.length
while (fieldsLeft > 0) {
sql.append("?")
if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
fieldsLeft = fieldsLeft - 1
}
conn.prepareStatement(sql.toString())
def insertStatement(conn: Connection,
dialect: JdbcDialect,
table: String,
rddSchema: StructType,
columnMapping: Map[String, String]): PreparedStatement = {
val sql = dialect.getInsertStatement(table, rddSchema, columnMapping)
conn.prepareStatement(sql)
}

/**
@@ -122,6 +122,7 @@
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
columnMapping: Map[String, String] = null,
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
Expand All @@ -139,7 +140,7 @@ object JdbcUtils extends Logging {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
}
val stmt = insertStatement(conn, table, rddSchema)
val stmt = insertStatement(conn, dialect, table, rddSchema, columnMapping)
try {
var rowCount = 0
while (iterator.hasNext) {
@@ -234,7 +235,8 @@
df: DataFrame,
url: String,
table: String,
properties: Properties = new Properties()) {
properties: Properties = new Properties(),
columnMapping: Map[String, String] = null) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
@@ -245,7 +247,8 @@
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
savePartition(getConnection, table, iterator, rddSchema, nullTypes,
columnMapping, batchSize, dialect)
}
}

sql/core/src/main/scala/org/apache/spark/sql/jdbc/CassandraDialect.scala (new file)
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import java.sql.Types

import org.apache.spark.sql.types._


private case object CassandraDialect extends JdbcDialect {

override def canHandle(url: String): Boolean =
url.startsWith("jdbc:datadirect:cassandra") ||
url.startsWith("jdbc:weblogic:cassandra")

override def getInsertStatement(table: String, rddSchema: StructType): String = {
val sql = new StringBuilder(s"INSERT INTO $table ( ")
var fieldsLeft = rddSchema.fields.length
var i = 0
// Build list of column names
while (fieldsLeft > 0) {
sql.append(rddSchema.fields(i).name)
if (fieldsLeft > 1) sql.append(", ")

Member (Sean Owen):
Nits: braces and newline for the if; use a += 1 instead of a = a + 1 in the two lines below; extra blank line above near the imports.

You should probably also make this more idiomatic and compact. For example this while loop collapses to rddSchema.fields.map(_.name).mkString(", "), I believe. Similarly for the final while. And then the entire method doesn't need to manually build it up with StringBuilder. This is probably a couple lines of code.
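
As a rough sketch of that suggestion (this is not the code in the PR; the parameter names just mirror the existing method), the whole body could shrink to something like:

import org.apache.spark.sql.types.StructType

def getInsertStatement(table: String, rddSchema: StructType): String = {
  val columns = rddSchema.fields.map(_.name).mkString(", ")        // comma-separated column names
  val placeholders = rddSchema.fields.map(_ => "?").mkString(", ") // one "?" placeholder per field
  s"INSERT INTO $table ( $columns ) VALUES ( $placeholders )"
}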

Contributor Author (Christian Kurz):
@sean: I am working on your suggested changes. It looks like the code can be collapsed into one or two lines.

Meanwhile I have found that inserting into an Oracle table that has more columns than the dataframe results in

java.sql.BatchUpdateException: ORA-00947: not enough values

whenever there are unmapped columns.

This does not matter as long as the table matches the dataframe exactly. But as soon as someone wants to insert into an existing table with more columns than the dataframe has, this is a problem.

So it may indeed be better to include the suggested change for other technologies as well.
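
To make the failure mode concrete, here is a rough illustration (table and column names are invented, not taken from the PR). With a two-column dataframe and a three-column Oracle table EMP(ENAME, EMPNO, DEPTNO), the positional statement generated today fails, while a name-qualified statement lets the database fill the unmapped column with its default:

// Positional form generated today: two values for a three-column table.
val positional = "INSERT INTO EMP VALUES (?, ?)"             // ORA-00947: not enough values
// Name-qualified form: DEPTNO is simply left to its default/NULL.
val named = "INSERT INTO EMP ( ENAME, EMPNO ) VALUES (?, ?)"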

The key question I see is: is it okay to rely on the dataframe column names matching the target table column names?

If so, do you suggest changing the default behaviour to include column names for all dialects?

Do Spark's automated tests have coverage for different databases? Would any regression be caught prior to merge?

Btw, re squashing commits: I will try, but for now I need to better understand how all of this works in GitHub.

Member (Sean Owen):
You're just saying that inserting a DataFrame of m columns into a table of n > m columns doesn't work, right? Yes, without column name mappings, I expect this to fail anytime m != n, for any database. Right now this assumes m = n implicitly.

You're right that adding names requires a mapping from data frame column names to DB column names. Hm, I wonder if this needs an optional Map allowing for overrides.
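
One possible shape for that idea (purely illustrative; the helper name and signature are assumptions, not part of this PR): start from an identity mapping over the DataFrame schema and let the caller override only the names that differ.

import org.apache.spark.sql.types.StructType

def resolveColumnMapping(
    rddSchema: StructType,
    overrides: Map[String, String] = Map.empty): Map[String, String] = {
  // Each field maps to a table column of the same name unless an override is supplied.
  rddSchema.fieldNames.map(name => name -> overrides.getOrElse(name, name)).toMap
}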

I don't think the regression tests cover all databases, no. I also don't think this can be specific to Oracle anyway.

My workflow for squashing N of the last commits is:

  • git rebase -i HEAD~N
  • Change all but the first "pick" to "squash" in the editor and save
  • Edit the commit message down to just 1 logical message and save
  • git push --force origin [your branch]

Contributor Author (Christian Kurz):
Yes, this is exactly the problem encountered.

What if we drop the new Dialect and add a new signature on DataFrameWriter instead (new columnMapping param):

def jdbc(url: String, table: String, connectionProperties: Properties,
    columnMapping: Map[String, String]): Unit

The old signature then continues using the column-name-free INSERT syntax, but for any advanced use cases (or technologies which do not support the column-name-free syntax) the new API can be used.

This ensures full backwards compatibility for all technologies.

If this is the way to go, I'd better start a new PR?

My preference would still be to keep the refactoring of moving generation of the INSERT statement into the Dialect (instead of in JdbcUtils). Does this make sense?
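
For illustration, calling the proposed overload could look roughly like this (the URL, table and column names are made up, and df is assumed to be an existing DataFrame):

import java.util.Properties

val props = new Properties()
props.setProperty("user", "scott")
props.setProperty("password", "tiger")

// Map DataFrame field names to the column names of the pre-existing target table.
val columnMapping = Map("name" -> "ENAME", "id" -> "EMPNO")
df.write.jdbc("jdbc:oracle:thin:@//dbhost:1521/orcl", "EMP", props, columnMapping)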

Member (Sean Owen):
@rxin are you the best person to ask about overloading DataFrameWriter.jdbc()?

Interesting question about maintaining the current behavior when no column name mapping is specified. In a way it still seems suboptimal to allow this behavior. What if there are the same number of columns, and all are the same type, but the ordering is different? You'd silently insert the wrong data in the wrong column.
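
A tiny example of that silent mis-ordering hazard (table and column names invented): a DataFrame with fields (first_name, last_name) written into a table created as (last_name, first_name). The positional statement pairs values by position, the name-qualified one by name:

// Positional: first_name values silently end up in the last_name column.
val positional = "INSERT INTO people VALUES (?, ?)"
// Name-qualified: each value goes to the column named after its DataFrame field.
val named = "INSERT INTO people ( first_name, last_name ) VALUES (?, ?)"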

Although specifying the DataFrame-to-table column name mapping can be optional (or, the caller can override only the names they want to) I think the SQL statement should be explicit. It does mean that someone who has a DataFrame with differently-named columns somehow might now encounter an exception, but I wonder if that's actually the right thing to enforce going forward. If it doesn't match up by name, don't proceed.

The API changes will take some care to make sure it's unintrusive and backwards compatible.

I suspect it doesn't do much harm to keep the insert statement logic in JdbcDialects though I imagine this behavior, whatever we decide, will be the right thing for all dialects, so it can be a default implementation there.

Member (Sean Owen):
I would just go ahead and add the new overload of jdbc().

Contributor Author (Christian Kurz):
Ok, will provide PR shortly

Contributor Author (Christian Kurz):
Opened PR #10312. Please let me know your thoughts. Thanks, Christian


fieldsLeft = fieldsLeft - 1
i = i + 1
}
sql.append(" ) VALUES ( ")
// Build values clause
fieldsLeft = rddSchema.fields.length
while (fieldsLeft > 0) {
sql.append("?")
if (fieldsLeft > 1) sql.append(", ")
fieldsLeft = fieldsLeft - 1
}
sql.append(" ) ")
return sql.toString()
}
}
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -108,6 +108,34 @@ abstract class JdbcDialect extends Serializable {
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}

/**
* Get the SQL statement that should be used to insert new records into the table.
* Dialects can override this method to return a statement that works best in a particular
* database.
* @param table The name of the table.
* @param rddSchema The schema of DataFrame to be inserted
* @param columnMapping An optional mapping from DataFrame field names to database column
* names
* @return The SQL statement to use for inserting into the table.
*/
def getInsertStatement(table: String,
rddSchema: StructType,
columnMapping: Map[String, String] = null): String = {
if (columnMapping == null) {
return rddSchema.fields.map(field => "?")
.mkString( s"INSERT INTO $table VALUES (", ", ", " ) ")
} else {
return rddSchema.fields.map(
field => columnMapping.get(field.name) match {
case Some(name) => name
case None => s"<JdbcDialect.getInsertStatement: No entry " +
s"found in columnMapping for field '${field.name}'>"
}
).mkString( s"INSERT INTO $table ( ", ", ", " ) " ) +
rddSchema.fields.map(field => "?").mkString( "VALUES ( ", ", ", " )" )
}
}

}

/**
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -96,6 +96,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}

test("Basic CREATE with columnMapping") {
val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)

val columnMapping = Map("name" -> "name", "id" -> "id")
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties, columnMapping)
assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
assert(
2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}

test("CREATE with overwrite") {
val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2)