From bcc86c0395ddc24cb629f46af9f985bdff0387a6 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 21 Nov 2016 21:49:42 -0800
Subject: [PATCH 1/7] fix.

---
 .../apache/spark/sql/DataFrameReader.scala    | 34 +++++------
 .../datasources/jdbc/JDBCRelation.scala       |  3 +-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 61 +++++++++++++------
 3 files changed, 63 insertions(+), 35 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a77937efd7e15..8bdc44901c06e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -159,7 +159,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, properties: Properties): DataFrame = {
-    jdbc(url, table, JDBCRelation.columnPartition(null), properties)
+    // connectionProperties should override settings in extraOptions.
+    this.extraOptions = this.extraOptions ++ properties.asScala
+    // explicit url and dbtable should override all
+    this.extraOptions += ("url" -> url, "dbtable" -> table)
+    format("jdbc").load()
   }
 
   /**
@@ -192,9 +196,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       upperBound: Long,
       numPartitions: Int,
       connectionProperties: Properties): DataFrame = {
-    val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
-    val parts = JDBCRelation.columnPartition(partitioning)
-    jdbc(url, table, parts, connectionProperties)
+    // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
+    this.extraOptions ++= Map(
+      JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
+      JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
+      JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
+      JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
+    jdbc(url, table, connectionProperties)
   }
 
   /**
@@ -220,22 +228,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       table: String,
       predicates: Array[String],
       connectionProperties: Properties): DataFrame = {
+    // connectionProperties should override settings in extraOptions.
+    val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
+    val options = new JDBCOptions(url, table, params)
     val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
       JDBCPartition(part, i) : Partition
     }
-    jdbc(url, table, parts, connectionProperties)
-  }
-
-  private def jdbc(
-      url: String,
-      table: String,
-      parts: Array[Partition],
-      connectionProperties: Properties): DataFrame = {
-    // connectionProperties should override settings in extraOptions.
-    this.extraOptions = this.extraOptions ++ connectionProperties.asScala
-    // explicit url and dbtable should override all
-    this.extraOptions += ("url" -> url, "dbtable" -> table)
-    format("jdbc").load()
+    val relation = JDBCRelation(parts, options)(sparkSession)
+    sparkSession.baseRelationToDataFrame(relation)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 672c21c6ac734..00015f82ced36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -137,7 +137,8 @@ private[sql] case class JDBCRelation(
   }
 
   override def toString: String = {
+    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]"
     // credentials should not be included in the plan output, table information is sufficient.
-    s"JDBCRelation(${jdbcOptions.table})"
+    s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 71cf5e6a22916..74a3e2ce810ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -209,6 +209,14 @@ class JDBCSuite extends SparkFunSuite
     conn.close()
   }
 
+  // Check whether the tables are fetched in the expected degree of parallelism
+  def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
+    val explain = ExplainCommand(df.queryExecution.logical, extended = true)
+    val plans = spark.sessionState.executePlan(explain).executedPlan
+    val expectedMsg = s"${JDBCOptions.JDBC_NUM_PARTITIONS}=$expectedNumPartitions"
+    assert(plans.executeCollect().map(_.toString).mkString.contains(expectedMsg))
+  }
+
   test("SELECT *") {
     assert(sql("SELECT * FROM foobar").collect().size === 3)
   }
@@ -313,13 +321,23 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * partitioned") {
-    assert(sql("SELECT * FROM parts").collect().size == 3)
+    val df = sql("SELECT * FROM parts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 3)
   }
 
   test("SELECT WHERE (simple predicates) partitioned") {
-    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0)
-    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2)
-    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1)
+    val df1 = sql("SELECT * FROM parts WHERE THEID < 1")
+    checkNumPartitions(df1, expectedNumPartitions = 3)
+    assert(df1.collect().length === 0)
+
+    val df2 = sql("SELECT * FROM parts WHERE THEID != 2")
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 2)
+
+    val df3 = sql("SELECT THEID FROM parts WHERE THEID = 1")
+    checkNumPartitions(df3, expectedNumPartitions = 3)
+    assert(df3.collect().length === 1)
   }
 
   test("SELECT second field partitioned") {
@@ -370,24 +388,27 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("Partitioning via JDBCPartitioningInfo API") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
-      .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning via list-of-where-clauses API") {
     val parts = Array[String]("THEID < 2", "THEID >= 2")
-    assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
-      .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 2)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning on column that might have null values.") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
-      .collect().length === 4)
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
-      .collect().length === 4)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 4)
+
+    val df2 = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 4)
+
     // partitioning on a nullable quoted column
     assert(
       spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties())
       .collect().length === 4)
@@ -404,6 +425,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 0,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
     assert(res.count() === 8)
   }
 
@@ -417,6 +439,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 10,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 4)
     assert(res.count() === 8)
   }
 
@@ -430,6 +453,7 @@ class JDBCSuite extends SparkFunSuite
      numPartitions = 4,
      connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
     assert(res.count() === 8)
   }
 
@@ -450,7 +474,9 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * on partitioned table with a nullable partition column") {
-    assert(sql("SELECT * FROM nullparts").collect().size == 4)
+    val df = sql("SELECT * FROM nullparts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 4)
   }
 
   test("H2 integral types") {
@@ -720,7 +746,8 @@ class JDBCSuite extends SparkFunSuite
     }
     // test the JdbcRelation toString output
     df.queryExecution.analyzed.collect {
-      case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)")
+      case r: LogicalRelation =>
+        assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE) [numPartitions=3]")
     }
   }

From 5c5b3cab4ec77484467e960319abbddb49313952 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 23 Nov 2016 23:24:08 -0800
Subject: [PATCH 2/7] address comments

---
 .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 +-
 .../spark/sql/execution/datasources/jdbc/JDBCRelation.scala   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 8bdc44901c06e..5c55b54d658ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -159,7 +159,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, properties: Properties): DataFrame = {
-    // connectionProperties should override settings in extraOptions.
+    // properties should override settings in extraOptions.
     this.extraOptions = this.extraOptions ++ properties.asScala
     // explicit url and dbtable should override all
     this.extraOptions += ("url" -> url, "dbtable" -> table)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 00015f82ced36..30caa73adc1ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -137,7 +137,7 @@ private[sql] case class JDBCRelation(
   }
 
   override def toString: String = {
-    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]"
+    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
     // credentials should not be included in the plan output, table information is sufficient.
     s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo
   }

From fdb73920c39b451ece5adea8543d6397128c4c4b Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 25 Nov 2016 19:45:46 -0800
Subject: [PATCH 3/7] fix test failure introduced by code merge.

---
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 84a878be6f0b0..20584b6b0fd78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -416,17 +416,18 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("Partitioning on column where numPartitions is zero") {
-    val res = spark.read.jdbc(
-      url = urlWithUserAndPass,
-      table = "TEST.seq",
-      columnName = "id",
-      lowerBound = 0,
-      upperBound = 4,
-      numPartitions = 0,
-      connectionProperties = new Properties()
-    )
-    checkNumPartitions(res, expectedNumPartitions = 1)
-    assert(res.count() === 8)
+    val e = intercept[IllegalArgumentException] {
+      spark.read.jdbc(
+        url = urlWithUserAndPass,
+        table = "TEST.seq",
+        columnName = "id",
+        lowerBound = 0,
+        upperBound = 4,
+        numPartitions = 0,
+        connectionProperties = new Properties()
+      )
+    }.getMessage
+    assert(e.contains("Invalid value `0` for parameter `numPartitions`. The minimum value is 1."))
   }

From 1b0caea20bd233ffda5113c11234d8fd57f6faa3 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Fri, 25 Nov 2016 20:57:12 -0800
Subject: [PATCH 4/7] nit

---
 .../apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index c2a1ad84b952e..dde940d491fe1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -657,7 +657,7 @@ object JdbcUtils extends Logging {
       df: DataFrame,
       url: String,
       table: String,
-      options: JDBCOptions) {
+      options: JDBCOptions): Unit = {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
       getJdbcType(field.dataType, dialect).jdbcNullType

From c29199c65ad61949ca015515ef4602e39d81d6ed Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Mon, 28 Nov 2016 10:19:44 -0800
Subject: [PATCH 5/7] fix.

---
 .../apache/spark/sql/DataFrameReader.scala    |  2 +-
 .../datasources/jdbc/JDBCOptions.scala        |  3 ---
 .../datasources/jdbc/JdbcUtils.scala          | 14 +++++------
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 23 +++++++++----------
 .../spark/sql/jdbc/JDBCWriteSuite.scala       |  5 +++-
 5 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9895f7b789f65..3796986c2521c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -162,7 +162,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     // properties should override settings in extraOptions.
     this.extraOptions = this.extraOptions ++ properties.asScala
     // explicit url and dbtable should override all
-    this.extraOptions += ("url" -> url, "dbtable" -> table)
+    this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
     format("jdbc").load()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index fe2f4c1d78647..56cd17816f7bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -76,9 +76,6 @@ class JDBCOptions(
   // the number of partitions
   val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
-  require(numPartitions.isEmpty || numPartitions.get > 0,
-    s"Invalid value `${numPartitions.get}` for parameter `$JDBC_NUM_PARTITIONS`. " +
-      "The minimum value is 1.")
 
   // ------------------------------------------------------------
   // Optional parameters only for reading
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index dde940d491fe1..ff29a15960c57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -667,13 +667,13 @@ object JdbcUtils extends Logging {
     val getConnection: () => Connection = createConnectionFactory(options)
     val batchSize = options.batchSize
     val isolationLevel = options.isolationLevel
-    val numPartitions = options.numPartitions
-    val repartitionedDF =
-      if (numPartitions.isDefined && numPartitions.get < df.rdd.getNumPartitions) {
-        df.coalesce(numPartitions.get)
-      } else {
-        df
-      }
+    val repartitionedDF = options.numPartitions match {
+      case Some(n) if n <= 0 => throw new IllegalArgumentException(
+        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
+          "via JDBC. The minimum value is 1.")
+      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
+      case _ => df
+    }
     repartitionedDF.foreachPartition(iterator => savePartition(
       getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 20584b6b0fd78..84a878be6f0b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -416,17 +416,17 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("Partitioning on column where numPartitions is zero") {
-    val e = intercept[IllegalArgumentException] {
-      spark.read.jdbc(
-        url = urlWithUserAndPass,
-        table = "TEST.seq",
-        columnName = "id",
-        lowerBound = 0,
-        upperBound = 4,
-        numPartitions = 0,
-        connectionProperties = new Properties()
-      )
-    }.getMessage
-    assert(e.contains("Invalid value `0` for parameter `numPartitions`. The minimum value is 1."))
+    val res = spark.read.jdbc(
+      url = urlWithUserAndPass,
+      table = "TEST.seq",
+      columnName = "id",
+      lowerBound = 0,
+      upperBound = 4,
+      numPartitions = 0,
+      connectionProperties = new Properties()
+    )
+    checkNumPartitions(res, expectedNumPartitions = 1)
+    assert(res.count() === 8)
   }
 
   test("Partitioning on column where numPartitions are more than the number of total rows") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index c834419948c53..064958bc1a14c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -319,9 +319,12 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
       df.write.format("jdbc")
         .option("dbtable", "TEST.SAVETEST")
         .option("url", url1)
+        .option("user", "testUser")
+        .option("password", "testPass")
        .option(s"${JDBCOptions.JDBC_NUM_PARTITIONS}", "0")
         .save()
     }.getMessage
-    assert(e.contains("Invalid value `0` for parameter `numPartitions`. The minimum value is 1"))
+    assert(e.contains("Invalid value `0` for parameter `numPartitions` in table writing " +
+      "via JDBC. The minimum value is 1."))
   }
 }

From 404aa223dc66419319de15798bf43abe00fd6e64 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Tue, 29 Nov 2016 23:47:27 -0800
Subject: [PATCH 6/7] address comments.

---
 .../org/apache/spark/sql/DataFrameReader.scala |  3 ++-
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala  | 14 ++++++++------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d6ca1ce9e8dea..365b50dee93c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -181,7 +181,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @param upperBound the maximum value of `columnName` used to decide partition stride.
    * @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive),
    *                      `upperBound` (exclusive), form partition strides for generated WHERE
-   *                      clause expressions used to split the column `columnName` evenly.
+   *                      clause expressions used to split the column `columnName` evenly. When
+   *                      the input is less than 1, the number is set to 1.
    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
    *                             tag/value. Normally at least a "user" and "password" property
    *                             should be included. "fetchsize" can be used to control the
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 8176aa72702d5..da105d175f569 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -24,12 +24,12 @@ import java.util.{Calendar, GregorianCalendar, Properties}
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -211,10 +211,12 @@ class JDBCSuite extends SparkFunSuite
 
   // Check whether the tables are fetched in the expected degree of parallelism
   def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
-    val explain = ExplainCommand(df.queryExecution.logical, extended = true)
-    val plans = spark.sessionState.executePlan(explain).executedPlan
-    val expectedMsg = s"${JDBCOptions.JDBC_NUM_PARTITIONS}=$expectedNumPartitions"
-    assert(plans.executeCollect().map(_.toString).mkString.contains(expectedMsg))
+    df.queryExecution.analyzed.collectFirst {
+      case LogicalRelation(r: JDBCRelation, _, _) if r.parts.length == expectedNumPartitions => ()
+    }.getOrElse {
+      fail(s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:\n" +
+        s"${df.queryExecution.analyzed}")
+    }
   }
 
   test("SELECT *") {

From 728c103fc10d5118eff4ff5bf9372da8557ecf60 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 30 Nov 2016 13:21:38 -0800
Subject: [PATCH 7/7] address comments.

---
 .../scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index da105d175f569..218ccf9332cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -211,12 +211,12 @@ class JDBCSuite extends SparkFunSuite
 
   // Check whether the tables are fetched in the expected degree of parallelism
   def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
-    df.queryExecution.analyzed.collectFirst {
-      case LogicalRelation(r: JDBCRelation, _, _) if r.parts.length == expectedNumPartitions => ()
-    }.getOrElse {
-      fail(s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:\n" +
-        s"${df.queryExecution.analyzed}")
+    val jdbcRelations = df.queryExecution.analyzed.collect {
+      case LogicalRelation(r: JDBCRelation, _, _) => r
     }
+    assert(jdbcRelations.length == 1)
+    assert(jdbcRelations.head.parts.length == expectedNumPartitions,
+      s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`")
   }
 
   test("SELECT *") {
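
For reference, a minimal usage sketch of the partitioned JDBC read path these patches rework; it is not part of any patch above. The table, column, and H2 URL below are placeholders borrowed from the test suite, and a reachable database is assumed. After PATCH 1/7, the explicit-argument overload simply fills in the partitionColumn / lowerBound / upperBound / numPartitions data source options, so the two reads are equivalent:

    // Illustrative sketch only; TEST.PEOPLE, THEID and the URL are placeholder names.
    import java.util.Properties
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("jdbc-read-sketch").master("local[*]").getOrCreate()
    val url = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"

    // Overload with explicit partitioning arguments (DataFrameReader.jdbc).
    val df1 = spark.read.jdbc(url, "TEST.PEOPLE", "THEID", 0L, 4L, 3, new Properties())

    // Equivalent option-based form that the overload now routes through.
    val df2 = spark.read.format("jdbc")
      .option("url", url)
      .option("dbtable", "TEST.PEOPLE")
      .option("partitionColumn", "THEID")
      .option("lowerBound", "0")
      .option("upperBound", "4")
      .option("numPartitions", "3")
      .load()

    // With these changes, both plans print as JDBCRelation(TEST.PEOPLE) [numPartitions=3].
    assert(df1.rdd.getNumPartitions == 3 && df2.rdd.getNumPartitions == 3)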