diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 32f1100406d74..11ab81f1498ba 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,6 +25,8 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
 
+import net.razorvine.pickle.{Pickler, Unpickler}
+
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
@@ -206,7 +208,7 @@ private object SpecialLengths {
   val TIMING_DATA = -3
 }
 
-private[spark] object PythonRDD {
+object PythonRDD {
   val UTF8 = Charset.forName("UTF-8")
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -284,6 +286,42 @@ private[spark] object PythonRDD {
     file.close()
   }
 
+  def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[_] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      val unpickle = new Unpickler
+      // TODO: Figure out why flatMap is necessary for pyspark
+      iter.flatMap { row =>
+        unpickle.loads(row) match {
+          case objs: java.util.ArrayList[Any] => objs
+          // In case the partition doesn't have a collection
+          case obj => Seq(obj)
+        }
+      }
+    }
+  }
+
+  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      val unpickle = new Unpickler
+      // TODO: Figure out why flatMap is necessary for pyspark
+      iter.flatMap { row =>
+        unpickle.loads(row) match {
+          case objs: java.util.ArrayList[JMap[String, _]] => objs.map(_.toMap)
+          // In case the partition doesn't have a collection
+          case obj: JMap[String, _] => Seq(obj.toMap)
+        }
+      }
+    }
+  }
+
+  def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+    jRDD.rdd.mapPartitions { iter =>
+      val pickle = new Pickler
+      iter.map { row =>
+        pickle.dumps(row)
+      }
+    }
+  }
 }
 
 private
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 843a874fbfdb0..60d34ab2fb0f2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -328,7 +328,8 @@ object SparkBuild extends Build {
         "com.twitter" %% "chill" % "0.3.1" excludeAll(excludeAsm),
         "com.twitter" % "chill-java" % "0.3.1" excludeAll(excludeAsm),
         "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock),
-        "com.clearspring.analytics" % "stream" % "2.5.1"
+        "com.clearspring.analytics" % "stream" % "2.5.1",
+        "net.razorvine" % "pyrolite_2.10" % "1.1"
       ),
     libraryDependencies ++= maybeAvro
   )
@@ -506,6 +507,7 @@ object SparkBuild extends Build {
 
   def extraAssemblySettings() = Seq(
     test in assembly := {},
+    assemblyOption in assembly ~= { _.copy(cacheOutput = false) },
     mergeStrategy in assembly := {
       case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
       case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
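Note on the pythonToJava / pythonToJavaMap helpers above: the flatMap (see the TODO) is most likely needed because PySpark's batched serializer pickles each batch of rows as a single list, so one partition element can unpickle to many rows. A minimal illustration of that assumption in plain Python (illustrative only, not part of the patch):

    import pickle

    # A PySpark partition element is typically a pickled *batch* of rows ...
    batch = [{"name": "alice", "age": 1}, {"name": "bob", "age": 2}]
    payload = pickle.dumps(batch)
    # ... so the JVM-side Unpickler sees one ArrayList per element and must flatMap it.
    assert pickle.loads(payload) == batch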
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d8667e84fedff..f30ebb9c8e7d9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -32,7 +32,7 @@
     PairDeserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, SchemaRDD
 
 from py4j.java_collections import ListConverter
 
@@ -174,6 +174,9 @@ def _ensure_initialized(cls, instance=None, gateway=None):
                 SparkContext._gateway = gateway or launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
+                SparkContext._pythonToJava = SparkContext._jvm.PythonRDD.pythonToJava
+                SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
+                SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython
 
             if instance:
                 if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
@@ -460,6 +463,28 @@ def sparkUser(self):
         """
         return self._jsc.sc().sparkUser()
 
+class SQLContext:
+
+    def __init__(self, sparkContext):
+        self._sc = sparkContext
+        self._jsc = self._sc._jsc
+        self._jvm = self._sc._jvm
+        self._ssql_ctx = self._jvm.SQLContext(self._jsc.sc())
+
+    def sql(self, sqlQuery):
+        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+
+    def applySchema(self, rdd):
+        if (rdd.__class__ is SchemaRDD):
+            raise Exception("Cannot apply schema to %s" % SchemaRDD.__name__)
+        elif type(rdd.first()) is not dict:
+            raise Exception("Only RDDs with dictionaries can be converted to %s" % SchemaRDD.__name__)
+
+        jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
+        srdd = self._ssql_ctx.applySchema(jrdd.rdd())
+        return SchemaRDD(srdd, self)
+
+
 def _test():
     import atexit
     import doctest
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 6a16756e0576d..d8dd2a65225e1 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -64,5 +64,6 @@ def run(self):
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
+    java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
     java_import(gateway.jvm, "scala.Tuple2")
     return gateway
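Usage sketch for the SQLContext wrapper added to context.py above (illustrative only, not part of the patch; assumes an existing SparkContext named sc and the Spark SQL classes on the assembly classpath):

    from pyspark.context import SQLContext

    sqlCtx = SQLContext(sc)
    people = sc.parallelize([{"name": "alice", "age": 1},
                             {"name": "bob", "age": 2}])
    peopleSchema = sqlCtx.applySchema(people)  # RDD of dicts -> SchemaRDD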
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fb27863e07f55..e11300212c4b1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1387,6 +1387,42 @@ def _jrdd(self):
     def _is_pipelinable(self):
         return not (self.is_cached or self.is_checkpointed)
 
+class Row(dict):
+
+    def __init__(self, d):
+        d.update(self.__dict__)
+        self.__dict__ = d
+        dict.__init__(self, d)
+
+class SchemaRDD(RDD):
+
+    def __init__(self, jschema_rdd, sql_ctx):
+        self.sql_ctx = sql_ctx
+        self._sc = sql_ctx._sc
+        self._jschema_rdd = jschema_rdd
+
+        self.is_cached = False
+        self.is_checkpointed = False
+        self.ctx = self.sql_ctx._sc
+        self._jrdd_deserializer = self.ctx.serializer
+
+    @property
+    def _jrdd(self):
+        return self.toPython()._jrdd
+
+    @property
+    def _id(self):
+        return self._jrdd.id()
+
+    def registerAsTable(self, name):
+        self._jschema_rdd.registerAsTable(name)
+
+    def toPython(self):
+        jrdd = self._jschema_rdd.javaToPython()
+        # TODO: This is inefficient; we should construct the Python Row object
+        # in Java land in the javaToPython function. May require a custom
+        # pickle serializer in Pyrolite.
+        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))
 
 def _test():
     import doctest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 36059c6630aa4..ef23cf1739246 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,11 +25,13 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.dsl
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
 import org.apache.spark.sql.execution._
+import org.apache.spark.api.java.JavaRDD
 
 /**
  * ALPHA COMPONENT
@@ -238,4 +240,29 @@ class SQLContext(@transient val sparkContext: SparkContext)
      */
     def debugExec() = DebugQuery(executedPlan).execute().collect()
   }
+
+  def applySchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    val schema = rdd.first.map { case (fieldName, obj) =>
+      val dataType = obj.getClass match {
+        case c: Class[_] if c == classOf[java.lang.String] => StringType
+        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+        // case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+        // case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+        // case c: Class[_] if c == java.lang.Long.TYPE => LongType
+        // case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+        // case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+        // case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+        // case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+      }
+      AttributeReference(fieldName, dataType, true)()
+    }.toSeq
+
+    val rowRdd = rdd.mapPartitions { iter =>
+      iter.map { map =>
+        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+      }
+    }
+    new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
+  }
+
 }
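Continuing the sketch above, the SchemaRDD added to rdd.py can be registered as a table and queried; because Row subclasses dict and copies its fields into __dict__, both key- and attribute-style access work (illustrative only, not part of the patch):

    peopleSchema.registerAsTable("people")
    adults = sqlCtx.sql("SELECT name, age FROM people WHERE age > 1")
    for row in adults.toPython().collect():
        print(row.name)    # attribute access via Row.__dict__
        print(row["age"])  # plain dict access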
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index fc95781448569..5dec32ef418df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import net.razorvine.pickle.{Pickler, Unpickler}
+
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis._
@@ -24,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.types.BooleanType
+import org.apache.spark.api.java.JavaRDD
+import java.util.{Map => JMap}
 
 /**
  * ALPHA COMPONENT
@@ -307,4 +311,20 @@ class SchemaRDD(
 
   /** FOR INTERNAL USE ONLY */
   def analyze = sqlContext.analyzer(logicalPlan)
+
+  def javaToPython: JavaRDD[Array[Byte]] = {
+    val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
+    this.mapPartitions { iter =>
+      val pickle = new Pickler
+      iter.map { row =>
+        val map: JMap[String, Any] = new java.util.HashMap
+        val arr: java.util.ArrayList[Any] = new java.util.ArrayList
+        row.zip(fieldNames).foreach { case (obj, name) =>
+          map.put(name, obj)
+        }
+        arr.add(map)
+        pickle.dumps(arr)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 573345e42c43c..4ca4505fbfc5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.api.java
 
 import java.beans.{Introspector, PropertyDescriptor}
+import java.util.{Map => JMap}
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.sql.SQLContext
@@ -82,7 +83,6 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
     new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
   }
 
-
   /**
    * Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
    */
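The javaToPython method above and PythonRDD.pythonToJavaMap assume a simple pickled wire format: each element is a one-element list holding a {fieldName: value} map. Roughly, a row round-trips like this (a sketch using the standard pickle module rather than Pyrolite; illustrative only, not part of the patch):

    import pickle

    # What SchemaRDD.javaToPython emits for one row: an ArrayList holding a HashMap,
    # which unpickles on the Python side as a list containing a dict.
    payload = pickle.dumps([{"name": "alice", "age": 1}])
    (row_dict,) = pickle.loads(payload)
    assert row_dict["name"] == "alice"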