diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..691c705bb0b4f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -35,14 +35,18 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
  * @param getConnection a function that returns an open Connection.
  *   The RDD takes care of closing the connection.
  * @param sql the text of the query.
- *   The query must contain two ? placeholders for parameters used to partition the results.
+ *   The query must contain two ? placeholders for parameters used to partition the results
+ *   when you want to use more than one partition.
  *   E.g. "select title, author from books where ? <= id and id <= ?"
+ *   If numPartitions is set to exactly 1, the query does not need to contain any ? placeholder.
  * @param lowerBound the minimum value of the first placeholder
  * @param upperBound the maximum value of the second placeholder
  *   The lower and upper bounds are inclusive.
+ *   If the query contains no ? placeholder, lowerBound and upperBound can be set to any value.
  * @param numPartitions the number of partitions.
  *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
  *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+ *   If the query contains no ? placeholder, numPartitions must be set to exactly 1.
  * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
  *   This should only call getInt, getString, etc; the RDD takes care of calling next.
  *   The default maps a ResultSet to an array of Object.
@@ -57,6 +61,8 @@ class JdbcRDD[T: ClassTag](
     mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
   extends RDD[T](sc, Nil) with Logging {
 
+  private var schema: Seq[(String, Int, Boolean)] = null
+
   override def getPartitions: Array[Partition] = {
     // bounds are inclusive, hence the + 1 here and - 1 on end
     val length = 1 + upperBound - lowerBound
@@ -67,6 +73,32 @@ class JdbcRDD[T: ClassTag](
     }).toArray
   }
 
+  def getSchema: Seq[(String, Int, Boolean)] = {
+    if (null != schema) {
+      return schema
+    }
+
+    val conn = getConnection()
+    val stmt = conn.prepareStatement(sql)
+    val metadata = stmt.getMetaData
+    try {
+      if (null != stmt && ! stmt.isClosed()) {
+        stmt.close()
+      }
+    } catch {
+      case e: Exception => logWarning("Exception closing statement", e)
+    }
+    schema = Seq[(String, Int, Boolean)]()
+    for(i <- 1 to metadata.getColumnCount) {
+      schema :+= (
+        metadata.getColumnName(i),
+        metadata.getColumnType(i),
+        metadata.isNullable(i) == java.sql.ResultSetMetaData.columnNullable
+      )
+    }
+    schema
+  }
+
   override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
     context.addTaskCompletionListener{ context => closeIfNeeded() }
     val part = thePart.asInstanceOf[JdbcPartition]
@@ -81,8 +113,14 @@ class JdbcRDD[T: ClassTag](
       logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
     }
 
-    stmt.setLong(1, part.lower)
-    stmt.setLong(2, part.upper)
+    val parameterCount = stmt.getParameterMetaData.getParameterCount
+    if (parameterCount > 0) {
+      stmt.setLong(1, part.lower)
+    }
+    if (parameterCount > 1) {
+      stmt.setLong(2, part.upper)
+    }
+
     val rs = stmt.executeQuery()
 
     override def getNext: T = {
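For illustration only (not part of the patch), a sketch of how the relaxed placeholder handling might be used, e.g. from spark-shell where `sc` is already defined; the Derby URL and the FOO table are assumptions made for the example:

```scala
import java.sql.{DriverManager, ResultSet}

import org.apache.spark.rdd.JdbcRDD

// With numPartitions = 1 the query carries no ? placeholders,
// so the (0, 0) bounds are never bound to the statement.
val rdd = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:derby:target/ExampleDb"),
  "SELECT DATA FROM FOO",
  0, 0, 1,
  (r: ResultSet) => r.getInt(1))

// getSchema returns (column name, java.sql.Types constant, nullable) triples
// taken from the prepared statement's metadata.
rdd.getSchema.foreach(println)
```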
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c8016e41256d5..c5e6db3190bb1 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -73,6 +73,11 @@
       <artifactId>junit</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.derby</groupId>
+      <artifactId>derby</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a75af94d29303..c6310aad140b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.JdbcRDD
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.dsl.ExpressionConversions
@@ -35,8 +36,10 @@ import org.apache.spark.sql.columnar.InMemoryRelation
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.SparkStrategies
 import org.apache.spark.sql.json._
+import org.apache.spark.sql.jdbc._
 import org.apache.spark.sql.parquet.ParquetRelation
 import org.apache.spark.{Logging, SparkContext}
+import java.sql.{DriverManager, ResultSet}
 
 /**
  * :: AlphaComponent ::
@@ -204,6 +207,54 @@ class SQLContext(@transient val sparkContext: SparkContext)
     applySchema(rowRDD, appliedSchema)
   }
 
+  /**
+   * Loads from JDBC, returning the ResultSet as a [[SchemaRDD]].
+   * The schema is determined from the ResultSet metadata of the prepared statement.
+   *
+   * @group userf
+   */
+  def jdbcResultSet(
+      connectString: String,
+      sql: String): SchemaRDD = {
+    jdbcResultSet(connectString, "", "", sql, 0, 0, 1)
+  }
+
+  def jdbcResultSet(
+      connectString: String,
+      username: String,
+      password: String,
+      sql: String): SchemaRDD = {
+    jdbcResultSet(connectString, username, password, sql, 0, 0, 1)
+  }
+
+  def jdbcResultSet(
+      connectString: String,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): SchemaRDD = {
+    jdbcResultSet(connectString, "", "", sql, lowerBound, upperBound, numPartitions)
+  }
+
+  def jdbcResultSet(
+      connectString: String,
+      username: String,
+      password: String,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): SchemaRDD = {
+    val resultSetRDD = new JdbcRDD(
+      sparkContext,
+      () => { DriverManager.getConnection(connectString, username, password) },
+      sql, lowerBound, upperBound, numPartitions,
+      (r: ResultSet) => r
+    )
+    val appliedSchema = JdbcResultSetRDD.inferSchema(resultSetRDD)
+    val rowRDD = JdbcResultSetRDD.jdbcResultSetToRow(resultSetRDD, appliedSchema)
+    applySchema(rowRDD, appliedSchema)
+  }
+
   /**
    * :: Experimental ::
    * Creates an empty parquet file with the schema of class `A`, which can be registered as a table.
@@ -411,7 +462,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     protected def stringOrError[A](f: => A): String =
       try f.toString catch { case e: Throwable => e.toString }
 
-    def simpleString: String = 
+    def simpleString: String =
       s"""== Physical Plan ==
          |${stringOrError(executedPlan)}
       """
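Likewise illustrative: a sketch of how the new jdbcResultSet overloads might be called from spark-shell, reusing the Derby database and FOO table that the test suite at the end of this patch creates; the partitioned query itself is an assumption for the example:

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Single-partition form: no ? placeholders; bounds default to (0, 0) and numPartitions to 1.
val foo = sqlContext.jdbcResultSet(
  "jdbc:derby:target/JdbcSchemaRDDSuiteDb",
  "SELECT DATA FROM FOO")
foo.registerTempTable("foo")
sqlContext.sql("SELECT sum(DATA) FROM foo").collect()

// Partitioned form: the two ? placeholders are bound to sub-ranges of (1, 100) across 4 partitions.
val partitioned = sqlContext.jdbcResultSet(
  "jdbc:derby:target/JdbcSchemaRDDSuiteDb",
  "SELECT ID, DATA FROM FOO WHERE ? <= ID AND ID <= ?",
  1, 100, 4)
```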
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala
new file mode 100644
index 0000000000000..5910846084e0c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDD.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.ResultSet
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.JdbcRDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+import org.apache.spark.Logging
+
+private[sql] object JdbcResultSetRDD extends Logging {
+
+  private[sql] def inferSchema(
+      jdbcResultSet: JdbcRDD[ResultSet]): StructType = {
+    StructType(createSchema(jdbcResultSet.getSchema))
+  }
+
+  private def createSchema(metaSchema: Seq[(String, Int, Boolean)]): Seq[StructField] = {
+    metaSchema.map(e => StructField(e._1, JdbcTypes.toPrimitiveDataType(e._2), e._3))
+  }
+
+  private[sql] def jdbcResultSetToRow(
+      jdbcResultSet: JdbcRDD[ResultSet],
+      schema: StructType): RDD[Row] = {
+    val row = new GenericMutableRow(schema.fields.length)
+    jdbcResultSet.map(asRow(_, row, schema.fields))
+  }
+
+  private def asRow(rs: ResultSet, row: GenericMutableRow, schemaFields: Seq[StructField]): Row = {
+    var i = 0
+    while (i < schemaFields.length) {
+      schemaFields(i).dataType match {
+        case StringType => row.update(i, rs.getString(i + 1))
+        case DecimalType => row.update(i, rs.getBigDecimal(i + 1))
+        case BooleanType => row.update(i, rs.getBoolean(i + 1))
+        case ByteType => row.update(i, rs.getByte(i + 1))
+        case ShortType => row.update(i, rs.getShort(i + 1))
+        case IntegerType => row.update(i, rs.getInt(i + 1))
+        case LongType => row.update(i, rs.getLong(i + 1))
+        case FloatType => row.update(i, rs.getFloat(i + 1))
+        case DoubleType => row.update(i, rs.getDouble(i + 1))
+        case BinaryType => row.update(i, rs.getBytes(i + 1))
+        case TimestampType => row.update(i, rs.getTimestamp(i + 1))
+        case _ => sys.error(
+          s"Unsupported jdbc datatype")
+      }
+      if (rs.wasNull) row.update(i, null)
+      i += 1
+    }
+
+    row
+  }
+}
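To make the inference concrete: for the FOO table created by the test suite below (ID INTEGER NOT NULL generated as identity, DATA INTEGER), getSchema would report ("ID", java.sql.Types.INTEGER, false) and ("DATA", java.sql.Types.INTEGER, true), so inferSchema should produce a StructType along these lines (illustrative only; this is not asserted anywhere in the patch):

```scala
import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}

// Expected inferred schema for FOO, written out by hand for illustration.
val fooSchema = StructType(Seq(
  StructField("ID", IntegerType, nullable = false),
  StructField("DATA", IntegerType, nullable = true)))
```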
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala
new file mode 100644
index 0000000000000..ada4709d69625
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcTypes.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] object JdbcTypes extends Logging {
+
+  /**
+   * More about JDBC types mapped to Java types:
+   * http://docs.oracle.com/javase/6/docs/technotes/guides/jdbc/getstart/mapping.html#1051555
+   *
+   * Compatibility of ResultSet getter methods as defined in the JDBC spec:
+   * http://download.oracle.com/otn-pub/jcp/jdbc-4_1-mrel-spec/jdbc4.1-fr-spec.pdf
+   * page 211
+   */
+  def toPrimitiveDataType(jdbcType: Int): DataType =
+    jdbcType match {
+      case java.sql.Types.LONGVARCHAR
+         | java.sql.Types.VARCHAR
+         | java.sql.Types.CHAR => StringType
+      case java.sql.Types.NUMERIC
+         | java.sql.Types.DECIMAL => DecimalType
+      case java.sql.Types.BIT => BooleanType
+      case java.sql.Types.TINYINT => ByteType
+      case java.sql.Types.SMALLINT => ShortType
+      case java.sql.Types.INTEGER => IntegerType
+      case java.sql.Types.BIGINT => LongType
+      case java.sql.Types.REAL => FloatType
+      case java.sql.Types.FLOAT
+         | java.sql.Types.DOUBLE => DoubleType
+      case java.sql.Types.LONGVARBINARY
+         | java.sql.Types.VARBINARY
+         | java.sql.Types.BINARY => BinaryType
+      // Timestamp's getter should also be able to get DATE and TIME according to the JDBC spec
+      case java.sql.Types.TIMESTAMP
+         | java.sql.Types.DATE
+         | java.sql.Types.TIME => TimestampType
+
+      // TODO: CLOB only works with getClob or getAsciiStream
+      // case java.sql.Types.CLOB
+
+      // TODO: BLOB only works with getBlob or getBinaryStream
+      // case java.sql.Types.BLOB
+
+      // TODO: nested types
+      // case java.sql.Types.ARRAY => ArrayType
+      // case java.sql.Types.STRUCT => StructType
+
+      // TODO: unsupported types
+      // case java.sql.Types.DISTINCT
+      // case java.sql.Types.REF
+
+      // TODO: more about JAVA_OBJECT:
+      // http://docs.oracle.com/javase/6/docs/technotes/guides/jdbc/getstart/mapping.html#1038181
+      // case java.sql.Types.JAVA_OBJECT => BinaryType
+
+      case _ => sys.error(
+        s"Unsupported jdbc datatype")
+    }
+}
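A small sketch of the mapping above. Since JdbcTypes is private[sql], the caller has to live under the org.apache.spark.sql package; the wrapper object here is hypothetical and only exists for the example:

```scala
package org.apache.spark.sql.jdbc

import java.sql.Types

import org.apache.spark.sql.catalyst.types._

// Hypothetical driver object showing a few of the mappings defined in JdbcTypes.
object JdbcTypesExample {
  def main(args: Array[String]): Unit = {
    println(JdbcTypes.toPrimitiveDataType(Types.VARCHAR))   // StringType
    println(JdbcTypes.toPrimitiveDataType(Types.INTEGER))   // IntegerType
    println(JdbcTypes.toPrimitiveDataType(Types.DECIMAL))   // DecimalType
    // DATE and TIME are read through getTimestamp, so they map to TimestampType.
    println(JdbcTypes.toPrimitiveDataType(Types.DATE))      // TimestampType
    // Unlisted types such as Types.CLOB currently fail with sys.error.
  }
}
```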
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala
new file mode 100644
index 0000000000000..354d7265dc2d6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JdbcResultSetRDDSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql._
+
+import org.scalatest.BeforeAndAfter
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.TestSQLContext._
+
+class JdbcResultSetRDDSuite extends QueryTest with BeforeAndAfter {
+
+  before {
+    Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcSchemaRDDSuiteDb;create=true")
+    try {
+      val create = conn.createStatement
+      create.execute("""
+        CREATE TABLE FOO(
+          ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
+          DATA INTEGER
+        )""")
+      create.close()
+      val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
+      (1 to 100).foreach { i =>
+        insert.setInt(1, i * 2)
+        insert.executeUpdate
+      }
+      insert.close()
+    } catch {
+      case e: SQLException if e.getSQLState == "X0Y32" =>
+        // table exists
+    } finally {
+      conn.close()
+    }
+  }
+
+  test("basic functionality") {
+    val jdbcResultSetRDD = jdbcResultSet("jdbc:derby:target/JdbcSchemaRDDSuiteDb", "SELECT DATA FROM FOO")
+    jdbcResultSetRDD.registerTempTable("foo")
+
+    checkAnswer(
+      sql("select count(*) from foo"),
+      100
+    )
+    checkAnswer(
+      sql("select sum(DATA) from foo"),
+      10100
+    )
+  }
+
+  after {
+    try {
+      DriverManager.getConnection("jdbc:derby:;shutdown=true")
+    } catch {
+      case se: SQLException if se.getSQLState == "XJ015" =>
+        // normal shutdown
+    }
+  }
+}