diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..691c705bb0b4f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -35,14 +35,18 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
* @param getConnection a function that returns an open Connection.
* The RDD takes care of closing the connection.
* @param sql the text of the query.
- * The query must contain two ? placeholders for parameters used to partition the results.
+ * The query must contain two ? placeholders for parameters used to partition the results
+ * when more than one partition is used.
* E.g. "select title, author from books where ? <= id and id <= ?"
+ * If numPartitions is set to exactly 1, the query does not need to contain any ? placeholders.
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
+ * If the query does not contain any ? placeholders, lowerBound and upperBound can be set to any value.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * If the query does not contain any ? placeholders, numPartitions must be set to exactly 1.
* @param mapRow a function from a ResultSet to a single row of the desired result type(s).
* This should only call getInt, getString, etc; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
@@ -57,6 +61,8 @@ class JdbcRDD[T: ClassTag](
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
extends RDD[T](sc, Nil) with Logging {
+ private var schema: Seq[(String, Int, Boolean)] = null
+
override def getPartitions: Array[Partition] = {
// bounds are inclusive, hence the + 1 here and - 1 on end
val length = 1 + upperBound - lowerBound
@@ -67,6 +73,32 @@ class JdbcRDD[T: ClassTag](
}).toArray
}
+ /** Returns (columnName, java.sql.Types constant, nullable) for each column of the query. */
+ def getSchema: Seq[(String, Int, Boolean)] = {
+ if (null != schema) {
+ return schema
+ }
+
+ val conn = getConnection()
+ val stmt = conn.prepareStatement(sql)
+ try {
+ val metadata = stmt.getMetaData
+ schema = (1 to metadata.getColumnCount).map { i =>
+ (metadata.getColumnName(i),
+ metadata.getColumnType(i),
+ metadata.isNullable(i) == java.sql.ResultSetMetaData.columnNullable)
+ }
+ } finally {
+ try {
+ if (null != stmt && !stmt.isClosed()) stmt.close()
+ if (null != conn && !conn.isClosed()) conn.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing statement or connection", e)
+ }
+ }
+ schema
+ }
+
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
@@ -81,8 +113,14 @@ class JdbcRDD[T: ClassTag](
logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
}
- stmt.setLong(1, part.lower)
- stmt.setLong(2, part.upper)
+ val parameterCount = stmt.getParameterMetaData.getParameterCount
+ if (parameterCount > 0) {
+ stmt.setLong(1, part.lower)
+ }
+ if (parameterCount > 1) {
+ stmt.setLong(2, part.upper)
+ }
+
val rs = stmt.executeQuery()
override def getNext: T = {
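A minimal usage sketch of the relaxed placeholder handling above, assuming a SparkContext named sc (e.g. in spark-shell) and the Derby table FOO created by the test suite at the end of this patch; the URL and column names are illustrative only, not part of the patch:

import java.sql.DriverManager
import org.apache.spark.rdd.JdbcRDD

// No ? placeholders in the query: numPartitions must be exactly 1 and the bounds are ignored.
val wholeTable = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:derby:target/JdbcSchemaRDDSuiteDb"),
  "SELECT DATA FROM FOO",
  0, 0, 1,
  r => r.getInt(1))

// Two ? placeholders: each partition binds its (lower, upper) slice of [1, 100].
val partitioned = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:derby:target/JdbcSchemaRDDSuiteDb"),
  "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
  1, 100, 4,
  r => r.getInt(1))

// getSchema returns (columnName, java.sql.Types constant, nullable) triples.
val columns: Seq[(String, Int, Boolean)] = wholeTable.getSchema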
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c8016e41256d5..c5e6db3190bb1 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -73,6 +73,11 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.derby</groupId>
+ <artifactId>derby</artifactId>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a75af94d29303..c6310aad140b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.JdbcRDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
@@ -35,8 +36,10 @@ import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.SparkStrategies
import org.apache.spark.sql.json._
+import org.apache.spark.sql.jdbc._
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.{Logging, SparkContext}
+import java.sql.{DriverManager, ResultSet}
/**
* :: AlphaComponent ::
@@ -204,6 +207,54 @@ class SQLContext(@transient val sparkContext: SparkContext)
applySchema(rowRDD, appliedSchema)
}
+ /**
+ * Loads data via JDBC and returns the result as a [[SchemaRDD]].
+ * The schema is determined from the ResultSetMetaData of the prepared statement for the given sql query.
+ *
+ * @group userf
+ */
+ def jdbcResultSet(
+ connectString: String,
+ sql: String): SchemaRDD = {
+ jdbcResultSet(connectString, "", "", sql, 0, 0, 1)
+ }
+
+ def jdbcResultSet(
+ connectString: String,
+ username: String,
+ password: String,
+ sql: String): SchemaRDD = {
+ jdbcResultSet(connectString, username, password, sql, 0, 0, 1)
+ }
+
+ def jdbcResultSet(
+ connectString: String,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): SchemaRDD = {
+ jdbcResultSet(connectString, "", "", sql, lowerBound, upperBound, numPartitions)
+ }
+
+ def jdbcResultSet(
+ connectString: String,
+ username: String,
+ password: String,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): SchemaRDD = {
+ val resultSetRDD = new JdbcRDD(
+ sparkContext,
+ () => { DriverManager.getConnection(connectString, username, password) },
+ sql, lowerBound, upperBound, numPartitions,
+ (r: ResultSet) => r
+ )
+ val appliedSchema = JdbcResultSetRDD.inferSchema(resultSetRDD)
+ val rowRDD = JdbcResultSetRDD.jdbcResultSetToRow(resultSetRDD, appliedSchema)
+ applySchema(rowRDD, appliedSchema)
+ }
+
/**
* :: Experimental ::
* Creates an empty parquet file with the schema of class `A`, which can be registered as a table.
@@ -411,7 +462,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
- def simpleString: String =
+ def simpleString: String =
s"""== Physical Plan ==
|${stringOrError(executedPlan)}
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala
new file mode 100644
index 0000000000000..5910846084e0c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.ResultSet
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.JdbcRDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.Logging
+
+private[sql] object JdbcResultSetRDD extends Logging {
+
+ private[sql] def inferSchema(
+ jdbcResultSet: JdbcRDD[ResultSet]): StructType = {
+ StructType(createSchema(jdbcResultSet.getSchema))
+ }
+
+ private def createSchema(metaSchema: Seq[(String, Int, Boolean)]): Seq[StructField] = {
+ metaSchema.map(e => StructField(e._1, JdbcTypes.toPrimitiveDataType(e._2), e._3))
+ }
+
+ private[sql] def jdbcResultSetToRow(
+ jdbcResultSet: JdbcRDD[ResultSet],
+ schema: StructType) : RDD[Row] = {
+ val row = new GenericMutableRow(schema.fields.length)
+ jdbcResultSet.map(asRow(_, row, schema.fields))
+ }
+
+ private def asRow(rs: ResultSet, row: GenericMutableRow, schemaFields: Seq[StructField]): Row = {
+ var i = 0
+ while (i < schemaFields.length) {
+ schemaFields(i).dataType match {
+ case StringType => row.update(i, rs.getString(i + 1))
+ case DecimalType => row.update(i, Option(rs.getBigDecimal(i + 1)).map(BigDecimal(_)).orNull)
+ case BooleanType => row.update(i, rs.getBoolean(i + 1))
+ case ByteType => row.update(i, rs.getByte(i + 1))
+ case ShortType => row.update(i, rs.getShort(i + 1))
+ case IntegerType => row.update(i, rs.getInt(i + 1))
+ case LongType => row.update(i, rs.getLong(i + 1))
+ case FloatType => row.update(i, rs.getFloat(i + 1))
+ case DoubleType => row.update(i, rs.getDouble(i + 1))
+ case BinaryType => row.update(i, rs.getBytes(i + 1))
+ case TimestampType => row.update(i, rs.getTimestamp(i + 1))
+ case other => sys.error(
+ s"Unsupported dataType: $other")
+ }
+ if (rs.wasNull) row.update(i, null)
+ i += 1
+ }
+
+ row
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala
new file mode 100644
index 0000000000000..ada4709d69625
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] object JdbcTypes extends Logging {
+
+ /**
+ * More about JDBC types mapped to Java types:
+ * http://docs.oracle.com/javase/6/docs/technotes/guides/jdbc/getstart/mapping.html#1051555
+ *
+ * Compatibility of ResultSet getter Methods defined in JDBC spec:
+ * http://download.oracle.com/otn-pub/jcp/jdbc-4_1-mrel-spec/jdbc4.1-fr-spec.pdf
+ * page 211
+ */
+ def toPrimitiveDataType(jdbcType: Int): DataType =
+ jdbcType match {
+ case java.sql.Types.LONGVARCHAR
+ | java.sql.Types.VARCHAR
+ | java.sql.Types.CHAR => StringType
+ case java.sql.Types.NUMERIC
+ | java.sql.Types.DECIMAL => DecimalType
+ case java.sql.Types.BIT => BooleanType
+ case java.sql.Types.TINYINT => ByteType
+ case java.sql.Types.SMALLINT => ShortType
+ case java.sql.Types.INTEGER => IntegerType
+ case java.sql.Types.BIGINT => LongType
+ case java.sql.Types.REAL => FloatType
+ case java.sql.Types.FLOAT
+ | java.sql.Types.DOUBLE => DoubleType
+ case java.sql.Types.LONGVARBINARY
+ | java.sql.Types.VARBINARY
+ | java.sql.Types.BINARY => BinaryType
+ // Timestamp's getter should also be able to get DATE and TIME according to JDBC spec
+ case java.sql.Types.TIMESTAMP
+ | java.sql.Types.DATE
+ | java.sql.Types.TIME => TimestampType
+
+ // TODO: CLOB only works with getClob or getAscIIStream
+ // case java.sql.Types.CLOB
+
+ // TODO: BLOB only works with getBlob or getBinaryStream
+ // case java.sql.Types.BLOB
+
+ // TODO: nested types
+ // case java.sql.Types.ARRAY => ArrayType
+ // case java.sql.Types.STRUCT => StructType
+
+ // TODO: unsupported types
+ // case java.sql.Types.DISTINCT
+ // case java.sql.Types.REF
+
+ // TODO: more about JAVA_OBJECT:
+ // http://docs.oracle.com/javase/6/docs/technotes/guides/jdbc/getstart/mapping.html#1038181
+ // case java.sql.Types.JAVA_OBJECT => BinaryType
+
+ case _ => sys.error(
+ s"Unsupported jdbc type: $jdbcType")
+ }
+}
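For a concrete picture of the mapping, the FOO table created by the test suite below (ID INTEGER NOT NULL, DATA INTEGER) would, under toPrimitiveDataType and the nullability check in JdbcResultSetRDD, infer a schema along these lines; this is an expected-result illustration, not code from the patch:

import org.apache.spark.sql.catalyst.types._

// Both columns are java.sql.Types.INTEGER, so they map to IntegerType;
// nullability follows ResultSetMetaData.isNullable.
val expected = StructType(Seq(
  StructField("ID", IntegerType, nullable = false),
  StructField("DATA", IntegerType, nullable = true)))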
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala
new file mode 100644
index 0000000000000..354d7265dc2d6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql._
+
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.TestSQLContext._
+
+class JdbcResultSetRDDSuite extends QueryTest with BeforeAndAfter {
+
+ before {
+ Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+ val conn = DriverManager.getConnection("jdbc:derby:target/JdbcSchemaRDDSuiteDb;create=true")
+ try {
+ val create = conn.createStatement
+ create.execute("""
+ CREATE TABLE FOO(
+ ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
+ DATA INTEGER
+ )""")
+ create.close()
+ val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
+ (1 to 100).foreach { i =>
+ insert.setInt(1, i * 2)
+ insert.executeUpdate
+ }
+ insert.close()
+ } catch {
+ case e: SQLException if e.getSQLState == "X0Y32" =>
+ // table exists
+ } finally {
+ conn.close()
+ }
+ }
+
+ test("basic functionality") {
+ val jdbcResultSetRDD =
+ jdbcResultSet("jdbc:derby:target/JdbcSchemaRDDSuiteDb", "SELECT DATA FROM FOO")
+ jdbcResultSetRDD.registerTempTable("foo")
+
+ checkAnswer(
+ sql("select count(*) from foo"),
+ 100
+ )
+ checkAnswer(
+ sql("select sum(DATA) from foo"),
+ 10100
+ )
+ }
+
+ after {
+ try {
+ DriverManager.getConnection("jdbc:derby:;shutdown=true")
+ } catch {
+ case se: SQLException if se.getSQLState == "XJ015" =>
+ // normal shutdown
+ }
+ }
+}