diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 349007789f63..fe15e4709dab 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -111,6 +111,18 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.spotify</groupId>
+      <artifactId>docker-client</artifactId>
+      <version>2.7.5</version>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <artifactId>guava</artifactId>
+          <groupId>com.google.guava</groupId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
new file mode 100644
index 000000000000..41c6f7ade3c6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.MutableList
+import scala.util.control.NonFatal
+
+import com.spotify.docker.client._
+import com.spotify.docker.client.messages.ContainerConfig
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
+
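+/**
+ * Describes a database running in a Docker container: the image to pull, the environment
+ * to run it with, and how to reach it over JDBC once it is up.
+ */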
+abstract class DatabaseOnDocker {
+ /**
+ * The docker image to be pulled
+ */
+ def imageName: String
+
+ /**
+ * A Seq of environment variables in the form of VAR=value
+ */
+ def env: Seq[String]
+
+ /**
+ * The JDBC URL for the database. This should be a lazy val or a def, since the `ip`
+ * it relies on is only available after the Docker container has started.
+ */
+ def jdbcUrl: String
+
+ private val docker: DockerClient = DockerClientFactory.get()
+ private var containerId: String = _
+
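+ // The container's IP address is only known once the container exists, so this must
+ // stay lazy; forcing it before start() has run would fail.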
+ lazy val ip: String = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+ def start(): Unit = {
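+ // Create and start the container. If the image is missing locally, pull it
+ // (retrying a few times) and loop to attempt the creation again.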
+ while (true) {
+ try {
+ val config = ContainerConfig.builder()
+ .image(imageName).env(env.asJava)
+ .build()
+ containerId = docker.createContainer(config).id
+ docker.startContainer(containerId)
+ return
+ } catch {
+ case _: ImageNotFoundException => retry(5)(docker.pull(imageName))
+ }
+ }
+ }
+
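+ /** Runs `fn`, retrying on non-fatal failure up to `n` total attempts. */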
+ private def retry[T](n: Int)(fn: => T): T = {
+ try {
+ fn
+ } catch {
+ case NonFatal(_) if n > 1 =>
+ retry(n - 1)(fn)
+ }
+ }
+
+ def close(): Unit = {
+ docker.killContainer(containerId)
+ docker.removeContainer(containerId)
+ DockerClientFactory.close(docker)
+ }
+}
+
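+/**
+ * Base class for JDBC integration tests: subclasses supply a [[DatabaseOnDocker]] and a
+ * `dataPreparation` body. `beforeAll` starts the container, polls until the database
+ * accepts connections, and then runs the preparation; `afterAll` tears the container down.
+ */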
+abstract class DatabaseIntegrationSuite extends SparkFunSuite
+ with BeforeAndAfterAll with SharedSQLContext {
+
+ def db: DatabaseOnDocker
+
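+ /** Polls `db.jdbcUrl` until a connection succeeds or `maxMillis` ms have elapsed. */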
+ def waitForDatabase(ip: String, maxMillis: Long): Unit = {
+ val before = System.currentTimeMillis()
+ var lastException: java.sql.SQLException = null
+ while (true) {
+ if (System.currentTimeMillis() > before + maxMillis) {
+ throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
+ }
+ try {
+ val conn = java.sql.DriverManager.getConnection(db.jdbcUrl)
+ conn.close()
+ return
+ } catch {
+ case e: java.sql.SQLException =>
+ lastException = e
+ java.lang.Thread.sleep(250)
+ }
+ }
+ }
+
+ def setupDatabase(ip: String): Unit = {
+ val conn: Connection = java.sql.DriverManager.getConnection(db.jdbcUrl)
+ try {
+ dataPreparation(conn)
+ } finally {
+ conn.close()
+ }
+ }
+
+ /**
+ * Prepare databases and tables for testing
+ */
+ def dataPreparation(connection: Connection): Unit
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ db.start()
+ waitForDatabase(db.ip, 60000)
+ setupDatabase(db.ip)
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ db.close()
+ } finally {
+ super.afterAll()
+ }
+ }
+}
+
+/**
+ * A factory and morgue for DockerClient objects. In the DockerClient we use,
+ * calling close() closes the desired DockerClient but also renders all other
+ * DockerClients inoperable. This is inconvenient if we have more than one
+ * open, such as during tests.
+ */
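+// Hypothetical usage sketch: clients must be released through the factory rather
+// than closed directly, so the factory can defer the shared close():
+//
+//   val dc = DockerClientFactory.get()
+//   try { /* use dc */ } finally { DockerClientFactory.close(dc) }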
+object DockerClientFactory {
+ private var numClients: Int = 0
+ private val zombies = new MutableList[DockerClient]()
+
+ def get(): DockerClient = {
+ this.synchronized {
+ numClients += 1
+ DefaultDockerClient.fromEnv.build()
+ }
+ }
+
+ def close(dc: DockerClient): Unit = {
+ this.synchronized {
+ numClients -= 1
+ zombies += dc
+ if (numClients == 0) {
+ zombies.foreach(_.close())
+ zombies.clear()
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
new file mode 100644
index 000000000000..b295894c8097
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import java.sql.{Connection, Date, Timestamp}
+import java.util.Properties
+
+class MySQLIntegrationSuite extends DatabaseIntegrationSuite {
+ val db = new DatabaseOnDocker {
+ val imageName = "mysql:latest"
+ val env = Seq("MYSQL_ROOT_PASSWORD=rootpass")
+ lazy val jdbcUrl = s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
+ }
+
+ override def dataPreparation(conn: Connection): Unit = {
+ conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+ conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate()
+ conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate()
+ conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), "
+ + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
+ + "dbl DOUBLE)").executeUpdate()
+ conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', "
+ + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
+ + "42.75, 1.0000000000000002)").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
+ + "yr YEAR)").executeUpdate()
+ conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
+ + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+ // TODO: Test locale conversion for strings.
+ conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
+ + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
+ ).executeUpdate()
+ conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
+ "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
+ }
+
+ test("Basic test") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "tbl", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 2)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 2)
+ assert(types(0).equals("class java.lang.Integer"))
+ assert(types(1).equals("class java.lang.String"))
+ }
+
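+ // MySQL BIT(1) is read back as a Boolean and BIT(10) as a Long;
+ // b'1000100101' is 549 decimal, i.e. 0x225.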
+ test("Numeric types") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.Boolean"))
+ assert(types(1).equals("class java.lang.Long"))
+ assert(types(2).equals("class java.lang.Integer"))
+ assert(types(3).equals("class java.lang.Integer"))
+ assert(types(4).equals("class java.lang.Integer"))
+ assert(types(5).equals("class java.lang.Long"))
+ assert(types(6).equals("class java.math.BigDecimal"))
+ assert(types(7).equals("class java.lang.Double"))
+ assert(types(8).equals("class java.lang.Double"))
+ assert(!rows(0).getBoolean(0))
+ assert(rows(0).getLong(1) == 0x225)
+ assert(rows(0).getInt(2) == 17)
+ assert(rows(0).getInt(3) == 77777)
+ assert(rows(0).getInt(4) == 123456789)
+ assert(rows(0).getLong(5) == 123456789012345L)
+ val bd = new BigDecimal("123456789012345.12345678901234500000")
+ assert(rows(0).getAs[BigDecimal](6).equals(bd))
+ assert(rows(0).getDouble(7) == 42.75)
+ assert(rows(0).getDouble(8) == 1.0000000000000002)
+ }
+
+ test("Date types") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 5)
+ assert(types(0).equals("class java.sql.Date"))
+ assert(types(1).equals("class java.sql.Timestamp"))
+ assert(types(2).equals("class java.sql.Timestamp"))
+ assert(types(3).equals("class java.sql.Timestamp"))
+ assert(types(4).equals("class java.sql.Date"))
+ assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09")))
+ assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24")))
+ assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45")))
+ assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30")))
+ assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01")))
+ }
+
+ test("String types") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.String"))
+ assert(types(1).equals("class java.lang.String"))
+ assert(types(2).equals("class java.lang.String"))
+ assert(types(3).equals("class java.lang.String"))
+ assert(types(4).equals("class java.lang.String"))
+ assert(types(5).equals("class java.lang.String"))
+ assert(types(6).equals("class [B"))
+ assert(types(7).equals("class [B"))
+ assert(types(8).equals("class [B"))
+ assert(rows(0).getString(0).equals("the"))
+ assert(rows(0).getString(1).equals("quick"))
+ assert(rows(0).getString(2).equals("brown"))
+ assert(rows(0).getString(3).equals("fox"))
+ assert(rows(0).getString(4).equals("jumps"))
+ assert(rows(0).getString(5).equals("over"))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103)))
+ }
+
+ test("Basic write test") {
+ val df1 = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
+ val df2 = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
+ val df3 = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
+ df1.write.jdbc(db.jdbcUrl, "numberscopy", new Properties)
+ df2.write.jdbc(db.jdbcUrl, "datescopy", new Properties)
+ df3.write.jdbc(db.jdbcUrl, "stringscopy", new Properties)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
new file mode 100644
index 000000000000..02d2a25d20b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+import java.util.Properties
+
+class PostgresIntegrationSuite extends DatabaseIntegrationSuite {
+ val db = new DatabaseOnDocker {
+ val imageName = "postgres:latest"
+ val env = Seq("POSTGRES_PASSWORD=rootpass")
+ lazy val jdbcUrl = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
+ }
+
+ override def dataPreparation(conn: Connection): Unit = {
+ conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+ conn.setCatalog("foo")
+ conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+ + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+ conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ }
+
+ test("Type mapping for various types") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 10)
+ assert(types(0).equals("class java.lang.String"))
+ assert(types(1).equals("class java.lang.Integer"))
+ assert(types(2).equals("class java.lang.Double"))
+ assert(types(3).equals("class java.lang.Long"))
+ assert(types(4).equals("class java.lang.Boolean"))
+ assert(types(5).equals("class [B"))
+ assert(types(6).equals("class [B"))
+ assert(types(7).equals("class java.lang.Boolean"))
+ assert(types(8).equals("class java.lang.String"))
+ assert(types(9).equals("class java.lang.String"))
+ assert(rows(0).getString(0).equals("hello"))
+ assert(rows(0).getInt(1) == 42)
+ assert(rows(0).getDouble(2) == 1.25)
+ assert(rows(0).getLong(3) == 123456789012345L)
+ assert(!rows(0).getBoolean(4))
+ // A BIT(10) comes back as an ASCII string of ten '0' and '1' characters.
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
+ Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
+ Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
+ assert(rows(0).getBoolean(7))
+ assert(rows(0).getString(8) == "172.16.0.42")
+ assert(rows(0).getString(9) == "192.168.0.0/16")
+ }
+
+ test("Basic write test") {
+ val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
+ df.write.jdbc(db.jdbcUrl, "public.barcopy", new Properties)
+ // Test only that it doesn't bomb out.
+ }
+}