@@ -59,6 +59,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
).executeUpdate()
conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
"'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()

conn.prepareStatement("CREATE TABLE `escaped names` (`key` BIGINT, `long description` TEXT)").executeUpdate()
conn.prepareStatement("INSERT INTO `escaped names` VALUES (123456789012345,'fred')").executeUpdate()
}

test("Basic test") {
@@ -150,4 +153,112 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
}

test("Basic test with table escaped names") {
val df = sqlContext.read.jdbc(jdbcUrl, "`escaped names`", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 2)
assert(types(0).equals("class java.lang.Long"))
assert(types(1).equals("class java.lang.String"))

df.write.mode("overwrite").jdbc(jdbcUrl, "`escaped names`", new Properties)

val df1 = sqlContext.read.jdbc(jdbcUrl, "`escaped names`", new Properties)
val rows1 = df1.collect()
assert(rows1.length == 1)
val types1 = rows1(0).toSeq.map(x => x.getClass.toString)
assert(types1.length == 2)
assert(types1(0).equals("class java.lang.Long"))
assert(types1(1).equals("class java.lang.String"))
}

test("Write test with SaveMode set to overwrite") {
Member:
Would it be possible to add a table or tables to the dataPreparation function that have column names that are key words? For example:

create table escaped names (key BIGINT, long description TEXT);

Contributor Author:
@pjfanning Added test as suggested.

val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 9)
assert(types(0).equals("class java.lang.Boolean"))
assert(types(1).equals("class java.lang.Long"))
assert(types(2).equals("class java.lang.Integer"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Integer"))
assert(types(5).equals("class java.lang.Long"))
assert(types(6).equals("class java.math.BigDecimal"))
assert(types(7).equals("class java.lang.Double"))
assert(types(8).equals("class java.lang.Double"))
assert(rows(0).getBoolean(0) == false)
assert(rows(0).getLong(1) == 0x225)
assert(rows(0).getInt(2) == 17)
assert(rows(0).getInt(3) == 77777)
assert(rows(0).getInt(4) == 123456789)
assert(rows(0).getLong(5) == 123456789012345L)
val bd = new BigDecimal("123456789012345.12345678901234500000")
assert(rows(0).getAs[BigDecimal](6).equals(bd))
assert(rows(0).getDouble(7) == 42.75)
assert(rows(0).getDouble(8) == 1.0000000000000002)
df.write.mode("overwrite").jdbc(jdbcUrl, "numbers", new Properties)

val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
val rows1 = df1.collect()
assert(rows1.length == 1)
val types1 = rows1(0).toSeq.map(x => x.getClass.toString)
assert(types1.length == 9)
assert(types1(0).equals("class java.lang.Boolean"))
assert(types1(1).equals("class java.lang.Long"))
assert(types1(2).equals("class java.lang.Integer"))
assert(types1(3).equals("class java.lang.Integer"))
assert(types1(4).equals("class java.lang.Integer"))
assert(types1(5).equals("class java.lang.Long"))
assert(types1(6).equals("class java.math.BigDecimal"))
assert(types1(7).equals("class java.lang.Double"))
assert(types1(8).equals("class java.lang.Double"))
assert(rows1(0).getBoolean(0) == false)
assert(rows1(0).getLong(1) == 0x225)
assert(rows1(0).getInt(2) == 17)
assert(rows1(0).getInt(3) == 77777)
assert(rows1(0).getInt(4) == 123456789)
assert(rows1(0).getLong(5) == 123456789012345L)
val bd1 = new BigDecimal("123456789012345.12345678901234500000")
assert(rows1(0).getAs[BigDecimal](6).equals(bd1))
assert(rows1(0).getDouble(7) == 42.75)
assert(rows1(0).getDouble(8) == 1.0000000000000002)
}

test("Write test with SaveMode set to append") {
val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 9)
assert(types(0).equals("class java.lang.Boolean"))
assert(types(1).equals("class java.lang.Long"))
assert(types(2).equals("class java.lang.Integer"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Integer"))
assert(types(5).equals("class java.lang.Long"))
assert(types(6).equals("class java.math.BigDecimal"))
assert(types(7).equals("class java.lang.Double"))
assert(types(8).equals("class java.lang.Double"))
assert(rows(0).getBoolean(0) == false)
assert(rows(0).getLong(1) == 0x225)
assert(rows(0).getInt(2) == 17)
assert(rows(0).getInt(3) == 77777)
assert(rows(0).getInt(4) == 123456789)
assert(rows(0).getLong(5) == 123456789012345L)
val bd = new BigDecimal("123456789012345.12345678901234500000")
assert(rows(0).getAs[BigDecimal](6).equals(bd))
assert(rows(0).getDouble(7) == 42.75)
assert(rows(0).getDouble(8) == 1.0000000000000002)
df.write.mode("append").jdbc(jdbcUrl, "numbers", new Properties)

val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
val rows1 = df1.collectAsList()
assert(rows1.size() == 2)
}
}

@@ -341,9 +341,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val conn = JdbcUtils.createConnectionFactory(url, props)()
val tableName = JdbcUtils.schemaQualifiedTableName(table, url)

try {
var tableExists = JdbcUtils.tableExists(conn, url, table)
var tableExists = JdbcUtils.tableExists(conn, url, tableName)

if (mode == SaveMode.Ignore && tableExists) {
return
@@ -354,14 +355,14 @@
}

if (mode == SaveMode.Overwrite && tableExists) {
JdbcUtils.dropTable(conn, table)
JdbcUtils.dropTable(conn, tableName)
tableExists = false
}

// Create the table if the table didn't exist.
if (!tableExists) {
val schema = JdbcUtils.schemaString(df, url)
val sql = s"CREATE TABLE $table ($schema)"
val sql = s"CREATE TABLE $tableName ($schema)"
val statement = conn.createStatement
try {
statement.executeUpdate(sql)
@@ -373,7 +374,7 @@
conn.close()
}

JdbcUtils.saveTable(df, url, table, props)
JdbcUtils.saveTable(df, url, tableName, props)
}
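
With the write path now resolving the name through JdbcUtils.schemaQualifiedTableName, the existence check, DROP, CREATE and save all use the dialect-quoted name. A usage sketch of what this enables on the write side, assuming a MySQL URL; the connection string, object and method names, and table names below are placeholders, not part of this change:

import java.util.Properties

import org.apache.spark.sql.{SQLContext, SaveMode}

object EscapedNameWriteSketch {
  // Sketch only: under MySQLDialect the target table is created and saved as
  // `escaped copy`, so a name containing a space or a reserved word no longer
  // needs manual backticks when writing.
  def copyToEscapedTable(sqlContext: SQLContext, jdbcUrl: String): Unit = {
    val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
    df.write.mode(SaveMode.Overwrite).jdbc(jdbcUrl, "escaped copy", new Properties)
  }
}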

/**

@@ -249,13 +249,21 @@ object JdbcUtils extends Logging {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
val name = dialect.quoteIdentifier(field.name)
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
if (sb.length < 2) "" else sb.substring(2)
}
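
Since column names now go through dialect.quoteIdentifier, a DataFrame whose columns are reserved words or contain spaces yields quoted DDL. Below is a self-contained sketch of the fragment this produces for MySQL; it is plain Scala that mimics the loop above rather than a call into Spark, and the type names are assumed mappings:

object QuotedSchemaStringSketch {
  def main(args: Array[String]): Unit = {
    // Mimics the ", $name $typ $nullable" append above, with MySQL backtick quoting.
    val fields = Seq(("key", "BIGINT", false), ("long description", "TEXT", true))
    val ddl = fields.map { case (name, typ, nullable) =>
      val quoted = s"`$name`"                          // what quoteIdentifier returns for MySQL
      val nullClause = if (nullable) "" else "NOT NULL"
      s"$quoted $typ $nullClause".trim
    }.mkString(", ")
    // Prints: `key` BIGINT NOT NULL, `long description` TEXT
    println(ddl)
  }
}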

/**
* Returns the table name quoted and schema-qualified as defined by the JDBC dialect
* for the given url.
*/

Contributor:
This is confusing, since the variable table already sounds like it would be the table name.
Can you call it schemaQualifiedTableName or something more explicit?

Contributor Author:
Adjusted as suggested.

def schemaQualifiedTableName(table: String, url: String): String = {
val dialect = JdbcDialects.get(url)
dialect.schemaQualifiedTableName(table)
}

/**
* Saves the RDD to the database in a single transaction.

@@ -89,6 +89,13 @@ abstract class JdbcDialect extends Serializable {
s""""$colName""""
}

/**
* Returns the table name, quoted and schema-qualified as needed for the dialect.
* The default implementation returns the name unchanged.
*/
def schemaQualifiedTableName(tableName: String): String = {
tableName
}

/**
* Get the SQL query that should be used to find if the given table exists. Dialects can
* override this method to return a query that works best in a particular database.

@@ -40,6 +40,34 @@ private case object MySQLDialect extends JdbcDialect {
override def quoteIdentifier(colName: String): String = {
s"`$colName`"
}

/**
* Processes a table name containing special characters: a dot separating the database
* name from the table name (e.g. "some database"."some-table-name"), or characters
* that require quoting (e.g. spaces).
*/
override def schemaQualifiedTableName(tableName: String): String = {
// Remove any existing quotes so that they can be re-added consistently.
val tableNameWithoutQuotes = tableName.replace("\"", "").replace("\'", "")

// Dot-qualified name (e.g. "some database"."some-table-name"): quote each non-empty part and re-join with '.'.
if (tableNameWithoutQuotes.contains(".")) {
val tableNameList = tableNameWithoutQuotes.split('.')
tableNameList.foldLeft("") { (leftStr, rightStr) =>
Contributor:
Woah, this is kinda dense and confusing. Can you please rewrite this in a simpler fashion and add comments? This is just my drive-by first impression.

Contributor:
Also, can you add unit tests for this function? Not end-to-end tests, but a test which exercises this method in isolation, similar to the ones that I have in spark-redshift?

Contributor:
I'm still not a fan of this foldLeft. Can you please write this in a more clear way?

Contributor:
Ping. Any updates here?

Contributor Author:
Will address it this week sometime.

if (!"".equals(rightStr.trim())) {
if ("".equals(leftStr.trim())) {
leftStr + s"`$rightStr`"
} else {
leftStr + "." + s"`$rightStr`"
}
} else {
leftStr
}
}
} else {
s"`$tableNameWithoutQuotes`"
}
}
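
For reference, a behavior-equivalent way to express the dot-handling above without the foldLeft the reviewers found dense; this is a sketch of the review suggestion, not code from this change:

object SchemaQualifiedTableNameSketch {
  // Same result as the override above: strip quotes, split on '.', drop empty
  // segments, backtick-quote each part, and re-join with '.'.
  def schemaQualifiedTableName(tableName: String): String = {
    val withoutQuotes = tableName.replace("\"", "").replace("'", "")
    if (withoutQuotes.contains(".")) {
      withoutQuotes.split('.')
        .filter(_.trim.nonEmpty)
        .map(part => s"`$part`")
        .mkString(".")
    } else {
      s"`$withoutQuotes`"
    }
  }

  def main(args: Array[String]): Unit = {
    // "some db.my table" -> `some db`.`my table`
    println(schemaQualifiedTableName("some db.my table"))
  }
}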

override def getTableExistsQuery(table: String): String = {
s"SELECT 1 FROM $table LIMIT 1"

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.jdbc

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext


class MySQLDialectSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {

val testMySQLDialect = JdbcDialects.get("jdbc:mysql://localhost:3306/mysql?user=root&password=rootpass")

test("testing schemaQualifiedTableName") {
assert(testMySQLDialect.schemaQualifiedTableName("some schema.some table").equals("`some schema`.`some table`"))
assert(testMySQLDialect.schemaQualifiedTableName("some schema.some table space.some table").equals("`some schema`.`some table space`.`some table`"))
assert(testMySQLDialect.schemaQualifiedTableName("some table").equals("`some table`"))
}
}