diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 349007789f63..fe15e4709dab 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -111,6 +111,18 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.spotify</groupId>
+      <artifactId>docker-client</artifactId>
+      <version>2.7.5</version>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <artifactId>guava</artifactId>
+          <groupId>com.google.guava</groupId>
+        </exclusion>
+      </exclusions>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
new file mode 100644
index 000000000000..41c6f7ade3c6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.MutableList
+
+import com.spotify.docker.client._
+import com.spotify.docker.client.messages.ContainerConfig
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
+
+abstract class DatabaseOnDocker {
+  /**
+   * The docker image to be pulled
+   */
+  def imageName: String
+
+  /**
+   * A Seq of environment variables in the form of VAR=value
+   */
+  def env: Seq[String]
+
+  /**
+   * jdbcUrl should be a lazy val or a function, since the `ip` it relies on is only
+   * available after the docker container starts
+   */
+  def jdbcUrl: String
+
+  private val docker: DockerClient = DockerClientFactory.get()
+  private var containerId: String = null
+
+  lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+  def start(): Unit = {
+    while (true) {
+      try {
+        val config = ContainerConfig.builder()
+          .image(imageName).env(env.asJava)
+          .build()
+        containerId = docker.createContainer(config).id
+        docker.startContainer(containerId)
+        return
+      } catch {
+        case e: ImageNotFoundException => retry(5)(docker.pull(imageName))
+      }
+    }
+  }
+
+  private def retry[T](n: Int)(fn: => T): T = {
+    try {
+      fn
+    } catch {
+      case e if n > 1 =>
+        retry(n - 1)(fn)
+    }
+  }
+
+  def close(): Unit = {
+    docker.killContainer(containerId)
+    docker.removeContainer(containerId)
+    DockerClientFactory.close(docker)
+  }
+}
+
+abstract class DatabaseIntegrationSuite extends SparkFunSuite
+  with BeforeAndAfterAll with SharedSQLContext {
+
+  def db: DatabaseOnDocker
+
+  def waitForDatabase(ip: String, maxMillis: Long) {
+    val before = System.currentTimeMillis()
+    var lastException: java.sql.SQLException = null
+    while (true) {
+      if (System.currentTimeMillis() > before + maxMillis) {
+        throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
+      }
+      try {
+        val conn = java.sql.DriverManager.getConnection(db.jdbcUrl)
+        conn.close()
+        return
+      } catch {
+        case e: java.sql.SQLException =>
+          lastException = e
+          java.lang.Thread.sleep(250)
+      }
+    }
+  }
+
+  def setupDatabase(ip: String): Unit = {
+    val conn: Connection = java.sql.DriverManager.getConnection(db.jdbcUrl)
+    try {
+      dataPreparation(conn)
+    } finally {
+      conn.close()
+    }
+  }
+
+  /**
+   * Prepare databases and tables for testing
+   */
+  def dataPreparation(connection: Connection)
+
+  override def beforeAll() {
+    super.beforeAll()
+    db.start()
+    waitForDatabase(db.ip, 60000)
+    setupDatabase(db.ip)
+  }
+
+  override def afterAll() {
+    try {
+      db.close()
+    } finally {
+      super.afterAll()
+    }
+  }
+}
+
+/**
+ * A factory and morgue for DockerClient objects. In the DockerClient we use,
+ * calling close() closes the desired DockerClient but also renders all other
+ * DockerClients inoperable. This is inconvenient if we have more than one
+ * open, such as during tests.
+ */
+object DockerClientFactory {
+  var numClients: Int = 0
+  val zombies = new MutableList[DockerClient]()
+
+  def get(): DockerClient = {
+    this.synchronized {
+      numClients = numClients + 1
+      DefaultDockerClient.fromEnv.build()
+    }
+  }
+
+  def close(dc: DockerClient) {
+    this.synchronized {
+      numClients = numClients - 1
+      zombies += dc
+      if (numClients == 0) {
+        zombies.foreach(_.close())
+        zombies.clear()
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
new file mode 100644
index 000000000000..b295894c8097
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import java.sql.{Connection, Date, Timestamp}
+import java.util.Properties
+
+class MySQLIntegrationSuite extends DatabaseIntegrationSuite {
+  val db = new DatabaseOnDocker {
+    val imageName = "mysql:latest"
+    val env = Seq("MYSQL_ROOT_PASSWORD=rootpass")
+    lazy val jdbcUrl = s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
+  }
+
+  override def dataPreparation(conn: Connection) {
+    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+    conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate()
+    conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate()
+    conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()
+
+    conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " +
+      "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " +
+      "dbl DOUBLE)").executeUpdate()
+    conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " +
+      "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " +
+      "42.75, 1.0000000000000002)").executeUpdate()
+
+    conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " +
+      "yr YEAR)").executeUpdate()
+    conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " +
+      "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+    // TODO: Test locale conversion for strings.
+    conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " +
+      "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
+      ).executeUpdate()
+    conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
+      "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
+  }
+
+  test("Basic test") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "tbl", new Properties)
+    val rows = df.collect()
+    assert(rows.length == 2)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types.length == 2)
+    assert(types(0).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.String"))
+  }
+
+  test("Numeric types") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
+    val rows = df.collect()
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types.length == 9)
+    assert(types(0).equals("class java.lang.Boolean"))
+    assert(types(1).equals("class java.lang.Long"))
+    assert(types(2).equals("class java.lang.Integer"))
+    assert(types(3).equals("class java.lang.Integer"))
+    assert(types(4).equals("class java.lang.Integer"))
+    assert(types(5).equals("class java.lang.Long"))
+    assert(types(6).equals("class java.math.BigDecimal"))
+    assert(types(7).equals("class java.lang.Double"))
+    assert(types(8).equals("class java.lang.Double"))
+    assert(rows(0).getBoolean(0) == false)
+    assert(rows(0).getLong(1) == 0x225)
+    assert(rows(0).getInt(2) == 17)
+    assert(rows(0).getInt(3) == 77777)
+    assert(rows(0).getInt(4) == 123456789)
+    assert(rows(0).getLong(5) == 123456789012345L)
+    val bd = new BigDecimal("123456789012345.12345678901234500000")
+    assert(rows(0).getAs[BigDecimal](6).equals(bd))
+    assert(rows(0).getDouble(7) == 42.75)
+    assert(rows(0).getDouble(8) == 1.0000000000000002)
+  }
+
+  test("Date types") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
+    val rows = df.collect()
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types.length == 5)
+    assert(types(0).equals("class java.sql.Date"))
+    assert(types(1).equals("class java.sql.Timestamp"))
+    assert(types(2).equals("class java.sql.Timestamp"))
+    assert(types(3).equals("class java.sql.Timestamp"))
+    assert(types(4).equals("class java.sql.Date"))
+    assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09")))
+    assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24")))
+    assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45")))
+    assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30")))
+    assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01")))
+  }
+
+  test("String types") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
+    val rows = df.collect()
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types.length == 9)
+    assert(types(0).equals("class java.lang.String"))
+    assert(types(1).equals("class java.lang.String"))
+    assert(types(2).equals("class java.lang.String"))
+    assert(types(3).equals("class java.lang.String"))
+    assert(types(4).equals("class java.lang.String"))
+    assert(types(5).equals("class java.lang.String"))
+    assert(types(6).equals("class [B"))
+    assert(types(7).equals("class [B"))
+    assert(types(8).equals("class [B"))
+    assert(rows(0).getString(0).equals("the"))
+    assert(rows(0).getString(1).equals("quick"))
+    assert(rows(0).getString(2).equals("brown"))
+    assert(rows(0).getString(3).equals("fox"))
+    assert(rows(0).getString(4).equals("jumps"))
+    assert(rows(0).getString(5).equals("over"))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0)))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121)))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103)))
+  }
+
+  test("Basic write test") {
+    val df1 = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
+    val df2 = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
+    val df3 = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
+    df1.write.jdbc(db.jdbcUrl, "numberscopy", new Properties)
+    df2.write.jdbc(db.jdbcUrl, "datescopy", new Properties)
+    df3.write.jdbc(db.jdbcUrl, "stringscopy", new Properties)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
new file mode 100644
index 000000000000..02d2a25d20b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.Connection
+import java.util.Properties
+
+class PostgresIntegrationSuite extends DatabaseIntegrationSuite {
+  val db = new DatabaseOnDocker {
+    val imageName = "postgres:latest"
+    val env = Seq("POSTGRES_PASSWORD=rootpass")
+    lazy val jdbcUrl = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
+  }
+
+  override def dataPreparation(conn: Connection) {
+    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+    conn.setCatalog("foo")
+    conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " +
+      "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+    conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " +
+      "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+  }
+
+  test("Type mapping for various types") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
+    val rows = df.collect()
+    assert(rows.length == 1)
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types.length == 10)
+    assert(types(0).equals("class java.lang.String"))
+    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Double"))
+    assert(types(3).equals("class java.lang.Long"))
+    assert(types(4).equals("class java.lang.Boolean"))
+    assert(types(5).equals("class [B"))
+    assert(types(6).equals("class [B"))
+    assert(types(7).equals("class java.lang.Boolean"))
+    assert(types(8).equals("class java.lang.String"))
+    assert(types(9).equals("class java.lang.String"))
+    assert(rows(0).getString(0).equals("hello"))
+    assert(rows(0).getInt(1) == 42)
+    assert(rows(0).getDouble(2) == 1.25)
+    assert(rows(0).getLong(3) == 123456789012345L)
+    assert(rows(0).getBoolean(4) == false)
+    // BIT(10) values come back as ASCII strings of ten '0' and '1' characters...
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
+      Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
+    assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
+      Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
+    assert(rows(0).getBoolean(7) == true)
+    assert(rows(0).getString(8) == "172.16.0.42")
+    assert(rows(0).getString(9) == "192.168.0.0/16")
+  }
+
+  test("Basic write test") {
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
+    df.write.jdbc(db.jdbcUrl, "public.barcopy", new Properties)
+    // Test only that it doesn't bomb out.
+  }
+}
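
Note on usage (not part of the diff): wiring an additional database into this harness only requires an image name, its environment variables, a JDBC URL built from the container's `ip`, and a `dataPreparation` body; start-up, the connection polling loop, and teardown are inherited from `DatabaseIntegrationSuite`. Below is a minimal sketch, assuming a hypothetical MariaDB suite -- the image tag, credentials, and table are illustrative, not part of this patch.

    package org.apache.spark.sql.jdbc

    import java.sql.Connection
    import java.util.Properties

    // Hypothetical example suite built on the API this patch introduces.
    class MariaDBIntegrationSuite extends DatabaseIntegrationSuite {
      val db = new DatabaseOnDocker {
        val imageName = "mariadb:latest"  // assumed image tag
        val env = Seq("MYSQL_ROOT_PASSWORD=rootpass")
        // Must be lazy: `ip` is only known once the container is running.
        lazy val jdbcUrl = s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
      }

      override def dataPreparation(conn: Connection) {
        conn.prepareStatement("CREATE TABLE t (x INTEGER)").executeUpdate()
      }

      test("Smoke test") {
        val df = sqlContext.read.jdbc(db.jdbcUrl, "t", new Properties)
        assert(df.collect().length == 0)
      }
    }

Each suite thus only describes what is unique to its database; everything Docker-related stays in DockerHacks.scala.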