From 134ebba01112b4ecc5c18bbc09ff09217438cf52 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 11 Aug 2015 15:00:23 +0800 Subject: [PATCH 1/9] revert 6136 and enable jdbc tests --- sql/core/pom.xml | 12 + .../apache/spark/sql/jdbc/DockerHacks.scala | 51 ++++ .../spark/sql/jdbc/MySQLIntegration.scala | 228 ++++++++++++++++++ .../spark/sql/jdbc/PostgresIntegration.scala | 147 +++++++++++ 4 files changed, 438 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 349007789f63..fe15e4709dab 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -111,6 +111,18 @@ mockito-core test + + com.spotify + docker-client + 2.7.5 + test + + + guava + com.google.guava + + + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala new file mode 100644 index 000000000000..f332cb389f33 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import scala.collection.mutable.MutableList + +import com.spotify.docker.client._ + +/** + * A factory and morgue for DockerClient objects. In the DockerClient we use, + * calling close() closes the desired DockerClient but also renders all other + * DockerClients inoperable. This is inconvenient if we have more than one + * open, such as during tests. + */ +object DockerClientFactory { + var numClients: Int = 0 + val zombies = new MutableList[DockerClient]() + + def get(): DockerClient = { + this.synchronized { + numClients = numClients + 1 + DefaultDockerClient.fromEnv.build() + } + } + + def close(dc: DockerClient) { + this.synchronized { + numClients = numClients - 1 + zombies += dc + if (numClients == 0) { + zombies.foreach(_.close()) + zombies.clear() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala new file mode 100644 index 000000000000..10b3d439fffb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Date, Timestamp} + +import com.spotify.docker.client.DockerClient +import com.spotify.docker.client.messages.ContainerConfig +import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} + +import org.apache.spark.sql.test._ + +class MySQLDatabase { + val docker: DockerClient = DockerClientFactory.get() + val containerId = { + println("Pulling mysql") + docker.pull("mysql") + println("Configuring container") + val config = ContainerConfig.builder().image("mysql") + .env("MYSQL_ROOT_PASSWORD=rootpass") + .build() + println("Creating container") + val id = docker.createContainer(config).id + println("Starting container " + id) + docker.startContainer(id) + id + } + val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + + def close() { + try { + println("Killing container " + containerId) + docker.killContainer(containerId) + println("Removing container " + containerId) + docker.removeContainer(containerId) + println("Closing docker client") + DockerClientFactory.close(docker) + } catch { + case e: Exception => + println(e) + println("You may need to clean this up manually.") + throw e + } + } +} + +class MySQLIntegration extends FunSuite with BeforeAndAfterAll { + var ip: String = null + + def url(ip: String): String = url(ip, "mysql") + def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" + + def waitForDatabase(ip: String, maxMillis: Long) { + println("Waiting for database to start up.") + val before = System.currentTimeMillis() + var lastException: java.sql.SQLException = null + while (true) { + if (System.currentTimeMillis() > before + maxMillis) { + throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException) + } + try { + val conn = java.sql.DriverManager.getConnection(url(ip)) + conn.close() + println("Database is up.") + return; + } catch { + case e: java.sql.SQLException => + lastException = e + java.lang.Thread.sleep(250) + } + } + } + + def setupDatabase(ip: String) { + val conn = java.sql.DriverManager.getConnection(url(ip)) + try { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt 
DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } finally { + conn.close() + } + } + + var db: MySQLDatabase = null + + override def beforeAll() { + // If you load the MySQL driver here, DriverManager will deadlock. The + // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres + // and H2 drivers. + //Class.forName("com.mysql.jdbc.Driver") + + db = new MySQLDatabase() + waitForDatabase(db.ip, 60000) + setupDatabase(db.ip) + ip = db.ip + } + + override def afterAll() { + db.close() + } + + test("Basic test") { + val df = TestSQLContext.jdbc(url(ip, "foo"), "tbl") + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = TestSQLContext.jdbc(url(ip, "foo"), "numbers") + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + println(types(1)) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = TestSQLContext.jdbc(url(ip, "foo"), "dates") + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9))) + assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0))) + assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0))) + assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0))) + assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1))) + } + + test("String types") { + val df = TestSQLContext.jdbc(url(ip, "foo"), "strings") + val rows = df.collect() + 
assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = TestSQLContext.jdbc(url(ip, "foo"), "numbers") + val df2 = TestSQLContext.jdbc(url(ip, "foo"), "dates") + val df3 = TestSQLContext.jdbc(url(ip, "foo"), "strings") + df1.createJDBCTable(url(ip, "foo"), "numberscopy", false) + df2.createJDBCTable(url(ip, "foo"), "datescopy", false) + df3.createJDBCTable(url(ip, "foo"), "stringscopy", false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala new file mode 100644 index 000000000000..4e4f3800ae46 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.sql.DriverManager + +import com.spotify.docker.client.DockerClient +import com.spotify.docker.client.messages.ContainerConfig +import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} + +import org.apache.spark.sql.test._ + +class PostgresDatabase { + val docker: DockerClient = DockerClientFactory.get() + val containerId = { + println("Pulling postgres") + docker.pull("postgres") + println("Configuring container") + val config = ContainerConfig.builder().image("postgres") + .env("POSTGRES_PASSWORD=rootpass") + .build() + println("Creating container") + val id = docker.createContainer(config).id + println("Starting container " + id) + docker.startContainer(id) + id + } + val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + + def close() { + try { + println("Killing container " + containerId) + docker.killContainer(containerId) + println("Removing container " + containerId) + docker.removeContainer(containerId) + println("Closing docker client") + DockerClientFactory.close(docker) + } catch { + case e: Exception => + println(e) + println("You may need to clean this up manually.") + throw e + } + } +} + +class PostgresIntegration extends FunSuite with BeforeAndAfterAll { + lazy val db = new PostgresDatabase() + + def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass" + + def waitForDatabase(ip: String, maxMillis: Long) { + val before = System.currentTimeMillis() + var lastException: java.sql.SQLException = null + while (true) { + if (System.currentTimeMillis() > before + maxMillis) { + throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", + lastException) + } + try { + val conn = java.sql.DriverManager.getConnection(url(ip)) + conn.close() + println("Database is up.") + return; + } catch { + case e: java.sql.SQLException => + lastException = e + java.lang.Thread.sleep(250) + } + } + } + + def setupDatabase(ip: String) { + val conn = DriverManager.getConnection(url(ip)) + try { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " + + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + } finally { + conn.close() + } + } + + override def beforeAll() { + println("Waiting for database to start up.") + waitForDatabase(db.ip, 60000) + println("Setting up database.") + setupDatabase(db.ip) + } + + override def afterAll() { + db.close() + } + + test("Type mapping for various types") { + val df = TestSQLContext.jdbc(url(db.ip), "public.bar") + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 10) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Double")) + assert(types(3).equals("class java.lang.Long")) + assert(types(4).equals("class java.lang.Boolean")) + assert(types(5).equals("class [B")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class java.lang.Boolean")) + assert(types(8).equals("class java.lang.String")) + assert(types(9).equals("class java.lang.String")) + assert(rows(0).getString(0).equals("hello")) + 
assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(rows(0).getBoolean(4) == false) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49,48,48,48,49,48,48,49,48,49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7) == true) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + } + + test("Basic write test") { + val df = TestSQLContext.jdbc(url(db.ip), "public.bar") + df.createJDBCTable(url(db.ip), "public.barcopy", false) + // Test only that it doesn't bomb out. + } +} From 32711ca61aaf13c1c4c0a80cd7a2162527e0ffd5 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 11 Aug 2015 18:31:25 +0800 Subject: [PATCH 2/9] fix style --- .../spark/sql/jdbc/MySQLIntegration.scala | 38 ++++++++++-------- .../spark/sql/jdbc/PostgresIntegration.scala | 40 ++++++++++--------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala index 10b3d439fffb..da43a9346317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala @@ -22,22 +22,23 @@ import java.sql.{Date, Timestamp} import com.spotify.docker.client.DockerClient import com.spotify.docker.client.messages.ContainerConfig -import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} +import org.apache.spark.{SparkFunSuite, Logging} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.test._ -class MySQLDatabase { +class MySQLDatabase extends Logging { val docker: DockerClient = DockerClientFactory.get() val containerId = { - println("Pulling mysql") + logInfo("Pulling mysql") docker.pull("mysql") - println("Configuring container") + logInfo("Configuring container") val config = ContainerConfig.builder().image("mysql") .env("MYSQL_ROOT_PASSWORD=rootpass") .build() - println("Creating container") + logInfo("Creating container") val id = docker.createContainer(config).id - println("Starting container " + id) + logInfo("Starting container " + id) docker.startContainer(id) id } @@ -45,29 +46,29 @@ class MySQLDatabase { def close() { try { - println("Killing container " + containerId) + logInfo("Killing container " + containerId) docker.killContainer(containerId) - println("Removing container " + containerId) + logInfo("Removing container " + containerId) docker.removeContainer(containerId) - println("Closing docker client") + logInfo("Closing docker client") DockerClientFactory.close(docker) } catch { case e: Exception => - println(e) - println("You may need to clean this up manually.") + logInfo(e.getMessage) + logInfo("You may need to clean this up manually.") throw e } } } -class MySQLIntegration extends FunSuite with BeforeAndAfterAll { +class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { var ip: String = null def url(ip: String): String = url(ip, "mysql") def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" def waitForDatabase(ip: String, maxMillis: Long) { - println("Waiting for database to start up.") + logInfo("Waiting for database to start up.") val before = System.currentTimeMillis() var 
lastException: java.sql.SQLException = null while (true) { @@ -77,7 +78,7 @@ class MySQLIntegration extends FunSuite with BeforeAndAfterAll { try { val conn = java.sql.DriverManager.getConnection(url(ip)) conn.close() - println("Database is up.") + logInfo("Database is up.") return; } catch { case e: java.sql.SQLException => @@ -111,7 +112,8 @@ class MySQLIntegration extends FunSuite with BeforeAndAfterAll { conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" ).executeUpdate() - conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() } finally { conn.close() } @@ -123,7 +125,9 @@ class MySQLIntegration extends FunSuite with BeforeAndAfterAll { // If you load the MySQL driver here, DriverManager will deadlock. The // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres // and H2 drivers. - //Class.forName("com.mysql.jdbc.Driver") + // scalastyle:off classforname + // Class.forName("com.mysql.jdbc.Driver") + // scalastyle:on classforname db = new MySQLDatabase() waitForDatabase(db.ip, 60000) @@ -151,7 +155,7 @@ class MySQLIntegration extends FunSuite with BeforeAndAfterAll { assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types.length == 9) - println(types(1)) + logInfo(types(1)) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Long")) assert(types(2).equals("class java.lang.Integer")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala index 4e4f3800ae46..523faa90eb0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala @@ -21,22 +21,23 @@ import java.sql.DriverManager import com.spotify.docker.client.DockerClient import com.spotify.docker.client.messages.ContainerConfig -import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} +import org.apache.spark.{SparkFunSuite, Logging} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.test._ -class PostgresDatabase { +class PostgresDatabase extends Logging { val docker: DockerClient = DockerClientFactory.get() val containerId = { - println("Pulling postgres") + logInfo("Pulling postgres") docker.pull("postgres") - println("Configuring container") + logInfo("Configuring container") val config = ContainerConfig.builder().image("postgres") .env("POSTGRES_PASSWORD=rootpass") .build() - println("Creating container") + logInfo("Creating container") val id = docker.createContainer(config).id - println("Starting container " + id) + logInfo("Starting container " + id) docker.startContainer(id) id } @@ -44,25 +45,26 @@ class PostgresDatabase { def close() { try { - println("Killing container " + containerId) + logInfo("Killing container " + containerId) docker.killContainer(containerId) - println("Removing container " + containerId) + logInfo("Removing container " + containerId) docker.removeContainer(containerId) - println("Closing docker client") + logInfo("Closing docker client") DockerClientFactory.close(docker) } catch { case e: Exception => - println(e) - 
println("You may need to clean this up manually.") + logInfo(e.getMessage) + logInfo("You may need to clean this up manually.") throw e } } } -class PostgresIntegration extends FunSuite with BeforeAndAfterAll { +class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { lazy val db = new PostgresDatabase() - def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass" + def url(ip: String): String = + s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass" def waitForDatabase(ip: String, maxMillis: Long) { val before = System.currentTimeMillis() @@ -75,7 +77,7 @@ class PostgresIntegration extends FunSuite with BeforeAndAfterAll { try { val conn = java.sql.DriverManager.getConnection(url(ip)) conn.close() - println("Database is up.") + logInfo("Database is up.") return; } catch { case e: java.sql.SQLException => @@ -100,9 +102,9 @@ class PostgresIntegration extends FunSuite with BeforeAndAfterAll { } override def beforeAll() { - println("Waiting for database to start up.") + logInfo("Waiting for database to start up.") waitForDatabase(db.ip, 60000) - println("Setting up database.") + logInfo("Setting up database.") setupDatabase(db.ip) } @@ -132,8 +134,10 @@ class PostgresIntegration extends FunSuite with BeforeAndAfterAll { assert(rows(0).getLong(3) == 123456789012345L) assert(rows(0).getBoolean(4) == false) // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49,48,48,48,49,48,48,49,48,49))) - assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) assert(rows(0).getBoolean(7) == true) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") From cefb3bc8461290134a48e9fb19b62e15e9912bcd Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 12 Aug 2015 22:05:52 +0800 Subject: [PATCH 3/9] bug fix --- .../spark/sql/jdbc/MySQLIntegration.scala | 34 ++++++------------- .../spark/sql/jdbc/PostgresIntegration.scala | 34 ++++++------------- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala index da43a9346317..e0b1b5020556 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala @@ -20,44 +20,33 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal import java.sql.{Date, Timestamp} +import org.scalatest.BeforeAndAfterAll + import com.spotify.docker.client.DockerClient import com.spotify.docker.client.messages.ContainerConfig -import org.apache.spark.{SparkFunSuite, Logging} -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ -class MySQLDatabase extends Logging { +class MySQLDatabase { val docker: DockerClient = DockerClientFactory.get() + val containerId = { - logInfo("Pulling mysql") docker.pull("mysql") - logInfo("Configuring container") val config = ContainerConfig.builder().image("mysql") .env("MYSQL_ROOT_PASSWORD=rootpass") .build() - logInfo("Creating 
container") val id = docker.createContainer(config).id - logInfo("Starting container " + id) docker.startContainer(id) id } + val ip = docker.inspectContainer(containerId).networkSettings.ipAddress def close() { - try { - logInfo("Killing container " + containerId) - docker.killContainer(containerId) - logInfo("Removing container " + containerId) - docker.removeContainer(containerId) - logInfo("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => - logInfo(e.getMessage) - logInfo("You may need to clean this up manually.") - throw e - } + docker.killContainer(containerId) + docker.removeContainer(containerId) + DockerClientFactory.close(docker) } } @@ -68,7 +57,6 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" def waitForDatabase(ip: String, maxMillis: Long) { - logInfo("Waiting for database to start up.") val before = System.currentTimeMillis() var lastException: java.sql.SQLException = null while (true) { @@ -78,8 +66,7 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { try { val conn = java.sql.DriverManager.getConnection(url(ip)) conn.close() - logInfo("Database is up.") - return; + return } catch { case e: java.sql.SQLException => lastException = e @@ -155,7 +142,6 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) assert(types.length == 9) - logInfo(types(1)) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Long")) assert(types(2).equals("class java.lang.Integer")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala index 523faa90eb0f..2ded203135ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala @@ -19,44 +19,33 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager +import org.scalatest.BeforeAndAfterAll + import com.spotify.docker.client.DockerClient import com.spotify.docker.client.messages.ContainerConfig -import org.apache.spark.{SparkFunSuite, Logging} -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ -class PostgresDatabase extends Logging { +class PostgresDatabase { val docker: DockerClient = DockerClientFactory.get() + val containerId = { - logInfo("Pulling postgres") docker.pull("postgres") - logInfo("Configuring container") val config = ContainerConfig.builder().image("postgres") .env("POSTGRES_PASSWORD=rootpass") .build() - logInfo("Creating container") val id = docker.createContainer(config).id - logInfo("Starting container " + id) docker.startContainer(id) id } + val ip = docker.inspectContainer(containerId).networkSettings.ipAddress def close() { - try { - logInfo("Killing container " + containerId) - docker.killContainer(containerId) - logInfo("Removing container " + containerId) - docker.removeContainer(containerId) - logInfo("Closing docker client") - DockerClientFactory.close(docker) - } catch { - case e: Exception => - logInfo(e.getMessage) - logInfo("You may need to clean this up manually.") - throw e - } + docker.killContainer(containerId) + docker.removeContainer(containerId) + DockerClientFactory.close(docker) } } @@ -77,8 
+66,7 @@ class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { try { val conn = java.sql.DriverManager.getConnection(url(ip)) conn.close() - logInfo("Database is up.") - return; + return } catch { case e: java.sql.SQLException => lastException = e @@ -102,9 +90,7 @@ class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { } override def beforeAll() { - logInfo("Waiting for database to start up.") waitForDatabase(db.ip, 60000) - logInfo("Setting up database.") setupDatabase(db.ip) } From 0569be8a4338b3ce139c988f7a8d588f3b0aeb15 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 13 Aug 2015 01:45:42 +0800 Subject: [PATCH 4/9] address comments --- ...ration.scala => MySQLIntegrationSuite.scala} | 17 +++++++++-------- ...ion.scala => PostgresIntegrationSuite.scala} | 7 ++++--- 2 files changed, 13 insertions(+), 11 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/jdbc/{MySQLIntegration.scala => MySQLIntegrationSuite.scala} (92%) rename sql/core/src/test/scala/org/apache/spark/sql/jdbc/{PostgresIntegration.scala => PostgresIntegrationSuite.scala} (94%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala similarity index 92% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala rename to sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index e0b1b5020556..ecc119fc720f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal import java.sql.{Date, Timestamp} +import java.util.Properties import org.scalatest.BeforeAndAfterAll @@ -50,7 +51,7 @@ class MySQLDatabase { } } -class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { +class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { var ip: String = null def url(ip: String): String = url(ip, "mysql") @@ -127,7 +128,7 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Basic test") { - val df = TestSQLContext.jdbc(url(ip, "foo"), "tbl") + val df = TestSQLContext.read.jdbc(url(ip, "foo"), "tbl", new Properties) val rows = df.collect() assert(rows.length == 2) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -137,7 +138,7 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Numeric types") { - val df = TestSQLContext.jdbc(url(ip, "foo"), "numbers") + val df = TestSQLContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -164,7 +165,7 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Date types") { - val df = TestSQLContext.jdbc(url(ip, "foo"), "dates") + val df = TestSQLContext.read.jdbc(url(ip, "foo"), "dates", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -182,7 +183,7 @@ class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("String types") { - val df = TestSQLContext.jdbc(url(ip, "foo"), "strings") + val df = TestSQLContext.read.jdbc(url(ip, "foo"), "strings", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -208,9 +209,9 @@ 
class MySQLIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Basic write test") { - val df1 = TestSQLContext.jdbc(url(ip, "foo"), "numbers") - val df2 = TestSQLContext.jdbc(url(ip, "foo"), "dates") - val df3 = TestSQLContext.jdbc(url(ip, "foo"), "strings") + val df1 = TestSQLContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) + val df2 = TestSQLContext.read.jdbc(url(ip, "foo"), "dates", new Properties) + val df3 = TestSQLContext.read.jdbc(url(ip, "foo"), "strings", new Properties) df1.createJDBCTable(url(ip, "foo"), "numberscopy", false) df2.createJDBCTable(url(ip, "foo"), "datescopy", false) df3.createJDBCTable(url(ip, "foo"), "stringscopy", false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala rename to sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 2ded203135ca..42f0c3eedb28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager +import java.util.Properties import org.scalatest.BeforeAndAfterAll @@ -49,7 +50,7 @@ class PostgresDatabase { } } -class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { +class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { lazy val db = new PostgresDatabase() def url(ip: String): String = @@ -99,7 +100,7 @@ class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Type mapping for various types") { - val df = TestSQLContext.jdbc(url(db.ip), "public.bar") + val df = TestSQLContext.read.jdbc(url(db.ip), "public.bar", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -130,7 +131,7 @@ class PostgresIntegration extends SparkFunSuite with BeforeAndAfterAll { } test("Basic write test") { - val df = TestSQLContext.jdbc(url(db.ip), "public.bar") + val df = TestSQLContext.read.jdbc(url(db.ip), "public.bar", new Properties) df.createJDBCTable(url(db.ip), "public.barcopy", false) // Test only that it doesn't bomb out. 
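     // A stronger check would read the copy back and compare it with the
     // source, e.g. (sketch only, not executed by this suite):
     //   val copy = TestSQLContext.read.jdbc(url(db.ip), "public.barcopy", new Properties)
     //   assert(copy.count() == df.count())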
   }
 }

From bcfb9a2b57861361fc796f04619286d7a5e6f9de Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Thu, 13 Aug 2015 02:01:01 +0800
Subject: [PATCH 5/9] fix deprecated api usage

---
 .../org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala | 6 +++---
 .../apache/spark/sql/jdbc/PostgresIntegrationSuite.scala  | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index ecc119fc720f..42c891523bf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -212,8 +212,8 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll {
     val df1 = TestSQLContext.read.jdbc(url(ip, "foo"), "numbers", new Properties)
     val df2 = TestSQLContext.read.jdbc(url(ip, "foo"), "dates", new Properties)
     val df3 = TestSQLContext.read.jdbc(url(ip, "foo"), "strings", new Properties)
-    df1.createJDBCTable(url(ip, "foo"), "numberscopy", false)
-    df2.createJDBCTable(url(ip, "foo"), "datescopy", false)
-    df3.createJDBCTable(url(ip, "foo"), "stringscopy", false)
+    df1.write.jdbc(url(ip, "foo"), "numberscopy", new Properties)
+    df2.write.jdbc(url(ip, "foo"), "datescopy", new Properties)
+    df3.write.jdbc(url(ip, "foo"), "stringscopy", new Properties)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 42f0c3eedb28..1c27ccd1fbbf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -132,7 +132,7 @@ class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll {

   test("Basic write test") {
     val df = TestSQLContext.read.jdbc(url(db.ip), "public.bar", new Properties)
-    df.createJDBCTable(url(db.ip), "public.barcopy", false)
+    df.write.jdbc(url(db.ip), "public.barcopy", new Properties)
     // Test only that it doesn't bomb out.
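     // createJDBCTable(url, table, allowExisting) was deprecated in Spark 1.4;
     // df.write.jdbc(url, table, properties) is the DataFrameWriter equivalent.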
} } From be26053d3795c0e983793dae008adfd167418976 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 14 Aug 2015 17:55:18 +0800 Subject: [PATCH 6/9] pull docker image with retries --- .../sql/jdbc/MySQLIntegrationSuite.scala | 54 ++++++++++++------- .../sql/jdbc/PostgresIntegrationSuite.scala | 39 ++++++++++---- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 42c891523bf0..044ab52bcaf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.scalatest.BeforeAndAfterAll -import com.spotify.docker.client.DockerClient +import com.spotify.docker.client.{ImageNotFoundException, DockerClient} import com.spotify.docker.client.messages.ContainerConfig import org.apache.spark.SparkFunSuite @@ -31,20 +31,37 @@ import org.apache.spark.sql.test._ class MySQLDatabase { val docker: DockerClient = DockerClientFactory.get() + var containerId: String = null - val containerId = { - docker.pull("mysql") - val config = ContainerConfig.builder().image("mysql") - .env("MYSQL_ROOT_PASSWORD=rootpass") - .build() - val id = docker.createContainer(config).id - docker.startContainer(id) - id + start() + + def start(): Unit = { + while (true) { + try { + val config = ContainerConfig.builder() + .image("mysql").env("MYSQL_ROOT_PASSWORD=rootpass") + .build() + containerId = docker.createContainer(config).id + docker.startContainer(containerId) + return + } catch { + case e: ImageNotFoundException => retry(3)(docker.pull("mysql")) + } + } } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + private def retry[T](n: Int)(fn: => T): T = { + try { + fn + } catch { + case e if n > 1 => + retry(n - 1)(fn) + } + } - def close() { + lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + + def close(): Unit = { docker.killContainer(containerId) docker.removeContainer(containerId) DockerClientFactory.close(docker) @@ -52,6 +69,7 @@ class MySQLDatabase { } class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { + lazy val db: MySQLDatabase = new MySQLDatabase() var ip: String = null def url(ip: String): String = url(ip, "mysql") @@ -107,8 +125,6 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } } - var db: MySQLDatabase = null - override def beforeAll() { // If you load the MySQL driver here, DriverManager will deadlock. 
The // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres @@ -116,8 +132,6 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { // scalastyle:off classforname // Class.forName("com.mysql.jdbc.Driver") // scalastyle:on classforname - - db = new MySQLDatabase() waitForDatabase(db.ip, 60000) setupDatabase(db.ip) ip = db.ip @@ -175,11 +189,11 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { assert(types(2).equals("class java.sql.Timestamp")) assert(types(3).equals("class java.sql.Timestamp")) assert(types(4).equals("class java.sql.Date")) - assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9))) - assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0))) - assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0))) - assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0))) - assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1))) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) } test("String types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 1c27ccd1fbbf..7fe9b94d6311 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.util.Properties import org.scalatest.BeforeAndAfterAll -import com.spotify.docker.client.DockerClient +import com.spotify.docker.client.{ImageNotFoundException, DockerClient} import com.spotify.docker.client.messages.ContainerConfig import org.apache.spark.SparkFunSuite @@ -30,20 +30,37 @@ import org.apache.spark.sql.test._ class PostgresDatabase { val docker: DockerClient = DockerClientFactory.get() + var containerId: String = null - val containerId = { - docker.pull("postgres") - val config = ContainerConfig.builder().image("postgres") - .env("POSTGRES_PASSWORD=rootpass") - .build() - val id = docker.createContainer(config).id - docker.startContainer(id) - id + start() + + def start(): Unit = { + while (true) { + try { + val config = ContainerConfig.builder() + .image("postgres").env("POSTGRES_PASSWORD=rootpass") + .build() + containerId = docker.createContainer(config).id + docker.startContainer(containerId) + return + } catch { + case e: ImageNotFoundException => retry(3)(docker.pull("postgres")) + } + } + } + + private def retry[T](n: Int)(fn: => T): T = { + try { + fn + } catch { + case e if n > 1 => + retry(n - 1)(fn) + } } - val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - def close() { + def close(): Unit = { docker.killContainer(containerId) docker.removeContainer(containerId) DockerClientFactory.close(docker) From a31a6281307d0faa7bdbd9c3680e2ce432c3390a Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 14 Aug 2015 19:01:30 +0800 Subject: [PATCH 7/9] use SharedSQLContext --- .../sql/jdbc/MySQLIntegrationSuite.scala | 19 ++++++++++--------- 
.../sql/jdbc/PostgresIntegrationSuite.scala | 9 +++++---- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 044ab52bcaf7..7945521da423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -27,7 +27,7 @@ import com.spotify.docker.client.{ImageNotFoundException, DockerClient} import com.spotify.docker.client.messages.ContainerConfig import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ +import org.apache.spark.sql.test.SharedSQLContext class MySQLDatabase { val docker: DockerClient = DockerClientFactory.get() @@ -68,7 +68,7 @@ class MySQLDatabase { } } -class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { +class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext { lazy val db: MySQLDatabase = new MySQLDatabase() var ip: String = null @@ -132,6 +132,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { // scalastyle:off classforname // Class.forName("com.mysql.jdbc.Driver") // scalastyle:on classforname + super.beforeAll() waitForDatabase(db.ip, 60000) setupDatabase(db.ip) ip = db.ip @@ -142,7 +143,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Basic test") { - val df = TestSQLContext.read.jdbc(url(ip, "foo"), "tbl", new Properties) + val df = sqlContext.read.jdbc(url(ip, "foo"), "tbl", new Properties) val rows = df.collect() assert(rows.length == 2) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -152,7 +153,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Numeric types") { - val df = TestSQLContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) + val df = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -179,7 +180,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Date types") { - val df = TestSQLContext.read.jdbc(url(ip, "foo"), "dates", new Properties) + val df = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -197,7 +198,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("String types") { - val df = TestSQLContext.read.jdbc(url(ip, "foo"), "strings", new Properties) + val df = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -223,9 +224,9 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Basic write test") { - val df1 = TestSQLContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) - val df2 = TestSQLContext.read.jdbc(url(ip, "foo"), "dates", new Properties) - val df3 = TestSQLContext.read.jdbc(url(ip, "foo"), "strings", new Properties) + val df1 = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties) + val df2 = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties) + val df3 = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties) df1.write.jdbc(url(ip, "foo"), "numberscopy", new 
Properties) df2.write.jdbc(url(ip, "foo"), "datescopy", new Properties) df3.write.jdbc(url(ip, "foo"), "stringscopy", new Properties) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 7fe9b94d6311..127d40fb095d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -26,7 +26,7 @@ import com.spotify.docker.client.{ImageNotFoundException, DockerClient} import com.spotify.docker.client.messages.ContainerConfig import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ +import org.apache.spark.sql.test.SharedSQLContext class PostgresDatabase { val docker: DockerClient = DockerClientFactory.get() @@ -67,7 +67,7 @@ class PostgresDatabase { } } -class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { +class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext { lazy val db = new PostgresDatabase() def url(ip: String): String = @@ -108,6 +108,7 @@ class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } override def beforeAll() { + super.beforeAll() waitForDatabase(db.ip, 60000) setupDatabase(db.ip) } @@ -117,7 +118,7 @@ class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Type mapping for various types") { - val df = TestSQLContext.read.jdbc(url(db.ip), "public.bar", new Properties) + val df = sqlContext.read.jdbc(url(db.ip), "public.bar", new Properties) val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) @@ -148,7 +149,7 @@ class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll { } test("Basic write test") { - val df = TestSQLContext.read.jdbc(url(db.ip), "public.bar", new Properties) + val df = sqlContext.read.jdbc(url(db.ip), "public.bar", new Properties) df.write.jdbc(url(db.ip), "public.barcopy", new Properties) // Test only that it doesn't bomb out. 
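     // sqlContext is now supplied by the SharedSQLContext trait (initialized
     // via super.beforeAll()) rather than by the TestSQLContext singleton.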
} From 502f8ad37d917411739386f453c607c7f182dfe2 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 15 Aug 2015 17:08:16 +0800 Subject: [PATCH 8/9] add image tag --- .../scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala | 2 +- .../org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 7945521da423..d95b8a0be5ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -45,7 +45,7 @@ class MySQLDatabase { docker.startContainer(containerId) return } catch { - case e: ImageNotFoundException => retry(3)(docker.pull("mysql")) + case e: ImageNotFoundException => retry(3)(docker.pull("mysql:latest")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 127d40fb095d..b8d81a420b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -44,7 +44,7 @@ class PostgresDatabase { docker.startContainer(containerId) return } catch { - case e: ImageNotFoundException => retry(3)(docker.pull("postgres")) + case e: ImageNotFoundException => retry(3)(docker.pull("postgres:latest")) } } } From fbc471bd8de9d7742b4329f79610ab9dd8fe123c Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 2 Sep 2015 12:55:00 +0800 Subject: [PATCH 9/9] refactoring --- .../apache/spark/sql/jdbc/DockerHacks.scala | 115 ++++++++++++ .../sql/jdbc/MySQLIntegrationSuite.scala | 168 +++++------------- .../sql/jdbc/PostgresIntegrationSuite.scala | 113 ++---------- 3 files changed, 171 insertions(+), 225 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala index f332cb389f33..41c6f7ade3c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala @@ -17,10 +17,125 @@ package org.apache.spark.sql.jdbc +import java.sql.Connection + +import scala.collection.JavaConverters._ import scala.collection.mutable.MutableList +import com.spotify.docker.client.messages.ContainerConfig import com.spotify.docker.client._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext +import org.scalatest.BeforeAndAfterAll + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled + */ + def imageName: String + + /** + * A Seq of environment variables in the form of VAR=value + */ + def env: Seq[String] + + /** + * jdbcUrl should be a lazy val or a function since `ip` it relies on is only available after + * the docker container starts + */ + def jdbcUrl: String + + private val docker: DockerClient = DockerClientFactory.get() + private var containerId: String = null + + lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress + + def start(): Unit = { + while (true) { + try { + val config = ContainerConfig.builder() + .image(imageName).env(env.asJava) + .build() + containerId = docker.createContainer(config).id + docker.startContainer(containerId) + return + } catch { + 
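+        // The image is missing locally: pull it (retry() allows up to five
+        // attempts) and loop back to retry the container creation.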
case e: ImageNotFoundException => retry(5)(docker.pull(imageName)) + } + } + } + + private def retry[T](n: Int)(fn: => T): T = { + try { + fn + } catch { + case e if n > 1 => + retry(n - 1)(fn) + } + } + + def close(): Unit = { + docker.killContainer(containerId) + docker.removeContainer(containerId) + DockerClientFactory.close(docker) + } +} + +abstract class DatabaseIntegrationSuite extends SparkFunSuite + with BeforeAndAfterAll with SharedSQLContext { + + def db: DatabaseOnDocker + + def waitForDatabase(ip: String, maxMillis: Long) { + val before = System.currentTimeMillis() + var lastException: java.sql.SQLException = null + while (true) { + if (System.currentTimeMillis() > before + maxMillis) { + throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException) + } + try { + val conn = java.sql.DriverManager.getConnection(db.jdbcUrl) + conn.close() + return + } catch { + case e: java.sql.SQLException => + lastException = e + java.lang.Thread.sleep(250) + } + } + } + + def setupDatabase(ip: String): Unit = { + val conn: Connection = java.sql.DriverManager.getConnection(db.jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } + + /** + * Prepare databases and tables for testing + */ + def dataPreparation(connection: Connection) + + override def beforeAll() { + super.beforeAll() + db.start() + waitForDatabase(db.ip, 60000) + setupDatabase(db.ip) + } + + override def afterAll() { + try { + db.close() + } finally { + super.afterAll() + } + } +} + /** * A factory and morgue for DockerClient objects. In the DockerClient we use, * calling close() closes the desired DockerClient but also renders all other diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index d95b8a0be5ed..b295894c8097 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -18,132 +18,44 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, Timestamp} +import java.sql.{Connection, Date, Timestamp} import java.util.Properties -import org.scalatest.BeforeAndAfterAll - -import com.spotify.docker.client.{ImageNotFoundException, DockerClient} -import com.spotify.docker.client.messages.ContainerConfig - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext - -class MySQLDatabase { - val docker: DockerClient = DockerClientFactory.get() - var containerId: String = null - - start() - - def start(): Unit = { - while (true) { - try { - val config = ContainerConfig.builder() - .image("mysql").env("MYSQL_ROOT_PASSWORD=rootpass") - .build() - containerId = docker.createContainer(config).id - docker.startContainer(containerId) - return - } catch { - case e: ImageNotFoundException => retry(3)(docker.pull("mysql:latest")) - } - } - } - - private def retry[T](n: Int)(fn: => T): T = { - try { - fn - } catch { - case e if n > 1 => - retry(n - 1)(fn) - } - } - - lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress - - def close(): Unit = { - docker.killContainer(containerId) - docker.removeContainer(containerId) - DockerClientFactory.close(docker) - } -} - -class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext { - lazy val db: MySQLDatabase = new MySQLDatabase() - var ip: String = null - - def url(ip: String): String = url(ip, 
"mysql") - def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass" - - def waitForDatabase(ip: String, maxMillis: Long) { - val before = System.currentTimeMillis() - var lastException: java.sql.SQLException = null - while (true) { - if (System.currentTimeMillis() > before + maxMillis) { - throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException) - } - try { - val conn = java.sql.DriverManager.getConnection(url(ip)) - conn.close() - return - } catch { - case e: java.sql.SQLException => - lastException = e - java.lang.Thread.sleep(250) - } - } - } - - def setupDatabase(ip: String) { - val conn = java.sql.DriverManager.getConnection(url(ip)) - try { - conn.prepareStatement("CREATE DATABASE foo").executeUpdate() - conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate() - conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), " - + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " - + "dbl DOUBLE)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', " - + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " - + "42.75, 1.0000000000000002)").executeUpdate() - - conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " - + "yr YEAR)").executeUpdate() - conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', " - + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() - - // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " - + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" - ).executeUpdate() - conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', " + - "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() - } finally { - conn.close() - } - } - - override def beforeAll() { - // If you load the MySQL driver here, DriverManager will deadlock. The - // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres - // and H2 drivers. 
-    // scalastyle:off classforname
-    // Class.forName("com.mysql.jdbc.Driver")
-    // scalastyle:on classforname
-    super.beforeAll()
-    waitForDatabase(db.ip, 60000)
-    setupDatabase(db.ip)
-    ip = db.ip
+class MySQLIntegrationSuite extends DatabaseIntegrationSuite {
+  val db = new DatabaseOnDocker {
+    val imageName = "mysql:latest"
+    val env = Seq("MYSQL_ROOT_PASSWORD=rootpass")
+    lazy val jdbcUrl = s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
   }
 
-  override def afterAll() {
-    db.close()
+  override def dataPreparation(conn: Connection) {
+    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+    conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate()
+    conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate()
+    conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()
+
+    conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), "
+      + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
+      + "dbl DOUBLE)").executeUpdate()
+    conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', "
+      + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
+      + "42.75, 1.0000000000000002)").executeUpdate()
+
+    conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
+      + "yr YEAR)").executeUpdate()
+    conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
+      + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+    // TODO: Test locale conversion for strings.
+    conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
+      + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
+      ).executeUpdate()
+    conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
+      "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
   }
 
   test("Basic test") {
-    val df = sqlContext.read.jdbc(url(ip, "foo"), "tbl", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "tbl", new Properties)
     val rows = df.collect()
     assert(rows.length == 2)
     val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -153,7 +65,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
   }
 
   test("Numeric types") {
-    val df = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -180,7 +92,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
   }
 
   test("Date types") {
-    val df = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -198,7 +110,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
   }
 
   test("String types") {
-    val df = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -224,11 +136,11 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
   }
 
   test("Basic write test") {
-    val df1 = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties)
-    val df2 = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties)
-    val df3 = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties)
-    df1.write.jdbc(url(ip, "foo"), "numberscopy", new Properties)
-    df2.write.jdbc(url(ip, "foo"), "datescopy", new Properties)
-    df3.write.jdbc(url(ip, "foo"), "stringscopy", new Properties)
+    val df1 = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
+    val df2 = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
+    val df3 = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
+    df1.write.jdbc(db.jdbcUrl, "numberscopy", new Properties)
+    df2.write.jdbc(db.jdbcUrl, "datescopy", new Properties)
+    df3.write.jdbc(db.jdbcUrl, "stringscopy", new Properties)
   }
 }
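With the refactoring, both suites funnel every read and write through a single db.jdbcUrl, with credentials embedded in the URL query string. The java.util.Properties argument that each jdbc() call already takes offers an equivalent channel; a hypothetical variant, not part of this patch, that keeps the URL free of secrets (the user/password values mirror the ones the suites configure):

    import java.util.Properties

    object PropsBasedCredentials {
      // Connection properties are forwarded to the JDBC driver, so user and
      // password can live here instead of in the URL's query string.
      val props = new Properties()
      props.setProperty("user", "root")
      props.setProperty("password", "rootpass")

      def url(ip: String): String = s"jdbc:mysql://$ip:3306/mysql"

      // With a live container and an initialized sqlContext this reads the
      // same rows as the URL-embedded form used by the suites:
      //   val df = sqlContext.read.jdbc(url(db.ip), "tbl", props)
    }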
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index b8d81a420b66..02d2a25d20b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -17,108 +17,27 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.DriverManager
+import java.sql.Connection
 import java.util.Properties
 
-import org.scalatest.BeforeAndAfterAll
-
-import com.spotify.docker.client.{ImageNotFoundException, DockerClient}
-import com.spotify.docker.client.messages.ContainerConfig
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.test.SharedSQLContext
-
-class PostgresDatabase {
-  val docker: DockerClient = DockerClientFactory.get()
-  var containerId: String = null
-
-  start()
-
-  def start(): Unit = {
-    while (true) {
-      try {
-        val config = ContainerConfig.builder()
-          .image("postgres").env("POSTGRES_PASSWORD=rootpass")
-          .build()
-        containerId = docker.createContainer(config).id
-        docker.startContainer(containerId)
-        return
-      } catch {
-        case e: ImageNotFoundException => retry(3)(docker.pull("postgres:latest"))
-      }
-    }
-  }
-
-  private def retry[T](n: Int)(fn: => T): T = {
-    try {
-      fn
-    } catch {
-      case e if n > 1 =>
-        retry(n - 1)(fn)
-    }
-  }
-
-  lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
-
-  def close(): Unit = {
-    docker.killContainer(containerId)
-    docker.removeContainer(containerId)
-    DockerClientFactory.close(docker)
-  }
-}
-
-class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext {
-  lazy val db = new PostgresDatabase()
-
-  def url(ip: String): String =
-    s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
-
-  def waitForDatabase(ip: String, maxMillis: Long) {
-    val before = System.currentTimeMillis()
-    var lastException: java.sql.SQLException = null
-    while (true) {
-      if (System.currentTimeMillis() > before + maxMillis) {
-        throw new java.sql.SQLException(s"Database not up after $maxMillis ms.",
-          lastException)
-      }
-      try {
-        val conn = java.sql.DriverManager.getConnection(url(ip))
-        conn.close()
-        return
-      } catch {
-        case e: java.sql.SQLException =>
-          lastException = e
-          java.lang.Thread.sleep(250)
-      }
-    }
-  }
-
-  def setupDatabase(ip: String) {
-    val conn = DriverManager.getConnection(url(ip))
-    try {
-      conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
-      conn.setCatalog("foo")
-      conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
-        + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
-      conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
-        + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
-    } finally {
-      conn.close()
-    }
-  }
-
-  override def beforeAll() {
-    super.beforeAll()
-    waitForDatabase(db.ip, 60000)
-    setupDatabase(db.ip)
+class PostgresIntegrationSuite extends DatabaseIntegrationSuite {
+  val db = new DatabaseOnDocker {
+    val imageName = "postgres:latest"
+    val env = Seq("POSTGRES_PASSWORD=rootpass")
+    lazy val jdbcUrl = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
   }
 
-  override def afterAll() {
-    db.close()
+  override def dataPreparation(conn: Connection) {
+    conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+    conn.setCatalog("foo")
+    conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+      + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+    conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
   }
 
   test("Type mapping for various types") {
-    val df = sqlContext.read.jdbc(url(db.ip), "public.bar", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -149,8 +68,8 @@ class PostgresIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with
   }
 
   test("Basic write test") {
-    val df = sqlContext.read.jdbc(url(db.ip), "public.bar", new Properties)
-    df.write.jdbc(url(db.ip), "public.barcopy", new Properties)
+    val df = sqlContext.read.jdbc(db.jdbcUrl, "bar", new Properties)
+    df.write.jdbc(db.jdbcUrl, "public.barcopy", new Properties)
     // Test only that it doesn't bomb out.
   }
 }
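Taken together, the series leaves DatabaseOnDocker and DatabaseIntegrationSuite as the extension points: a new database needs only an image name, its environment variables, a JDBC URL, and a dataPreparation callback, while container lifecycle, the readiness poll, and setup/teardown are inherited. A minimal sketch for a hypothetical MariaDB suite (image, env contract, and port are assumptions modeled on the MySQL case; nothing in this series adds such a suite):

    package org.apache.spark.sql.jdbc

    import java.sql.Connection
    import java.util.Properties

    class MariaDBIntegrationSuite extends DatabaseIntegrationSuite {
      val db = new DatabaseOnDocker {
        val imageName = "mariadb:latest"                // assumed image
        val env = Seq("MYSQL_ROOT_PASSWORD=rootpass")   // assumed env variable
        // MariaDB speaks the MySQL wire protocol, hence the jdbc:mysql scheme.
        lazy val jdbcUrl = s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
      }

      override def dataPreparation(conn: Connection) {
        conn.prepareStatement("CREATE TABLE tbl (x INTEGER)").executeUpdate()
        conn.prepareStatement("INSERT INTO tbl VALUES (42)").executeUpdate()
      }

      test("round trip") {
        val df = sqlContext.read.jdbc(db.jdbcUrl, "tbl", new Properties)
        assert(df.collect().map(_.getInt(0)).toSeq == Seq(42))
      }
    }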