diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 32f1100406d74..11ab81f1498ba 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,6 +25,8 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import net.razorvine.pickle.{Pickler, Unpickler}
+
import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
@@ -206,7 +208,7 @@ private object SpecialLengths {
val TIMING_DATA = -3
}
-private[spark] object PythonRDD {
+object PythonRDD {
val UTF8 = Charset.forName("UTF-8")
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -284,6 +286,42 @@ private[spark] object PythonRDD {
file.close()
}
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[_] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ val unpickle = new Unpickler
+ // TODO: Figure out why flatMap is necessary for pyspark
+ iter.flatMap { row =>
+ unpickle.loads(row) match {
+ case objs: java.util.ArrayList[Any] => objs
+ // In case the partition doesn't have a collection
+ case obj => Seq(obj)
+ }
+ }
+ }
+ }
+
+ def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ val unpickle = new Unpickler
+ // TODO: Figure out why flatMap is necessary for pyspark
+ iter.flatMap { row =>
+ unpickle.loads(row) match {
+ case objs: java.util.ArrayList[JMap[String, _]] => objs.map(_.toMap)
+ // In case the partition doesn't have a collection
+ case obj: JMap[String, _] => Seq(obj.toMap)
+ }
+ }
+ }
+ }
+
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ val pickle = new Pickler
+ iter.map { row =>
+ pickle.dumps(row)
+ }
+ }
+ }
}
private
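Note: the three helpers above form the pickle bridge between Python workers and the JVM. A rough sketch of how the map converter is reached from Python over the py4j gateway, mirroring what context.py does internally (assumes an active SparkContext `sc` and an assembly jar that includes these new methods):

```python
# Sketch only: exercise the new converter directly over py4j.
rdd = sc.parallelize([{"name": "alice", "age": 1}])

# rdd._jrdd is a JavaRDD[Array[Byte]] of pickled batches; pythonToJavaMap
# unpickles each batch into Map[String, _] objects on the JVM side, ready to be
# handed to SQLContext.applySchema. javaToPython is the reverse direction and
# is used by SchemaRDD.javaToPython further down in this patch.
jmap_rdd = sc._jvm.PythonRDD.pythonToJavaMap(rdd._jrdd)
```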
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 843a874fbfdb0..60d34ab2fb0f2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -328,7 +328,8 @@ object SparkBuild extends Build {
"com.twitter" %% "chill" % "0.3.1" excludeAll(excludeAsm),
"com.twitter" % "chill-java" % "0.3.1" excludeAll(excludeAsm),
"org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock),
- "com.clearspring.analytics" % "stream" % "2.5.1"
+ "com.clearspring.analytics" % "stream" % "2.5.1",
+ "net.razorvine" % "pyrolite_2.10" % "1.1"
),
libraryDependencies ++= maybeAvro
)
@@ -506,6 +507,7 @@ object SparkBuild extends Build {
def extraAssemblySettings() = Seq(
test in assembly := {},
+ assemblyOption in assembly ~= { _.copy(cacheOutput = false) },
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d8667e84fedff..f30ebb9c8e7d9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -32,7 +32,7 @@
PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, SchemaRDD
from py4j.java_collections import ListConverter
@@ -174,6 +174,9 @@ def _ensure_initialized(cls, instance=None, gateway=None):
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
+ SparkContext._pythonToJava = SparkContext._jvm.PythonRDD.pythonToJava
+ SparkContext._pythonToJavaMap = SparkContext._jvm.PythonRDD.pythonToJavaMap
+ SparkContext._javaToPython = SparkContext._jvm.PythonRDD.javaToPython
if instance:
if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
@@ -460,6 +463,28 @@ def sparkUser(self):
"""
return self._jsc.sc().sparkUser()
+class SQLContext:
+
+ def __init__(self, sparkContext):
+ self._sc = sparkContext
+ self._jsc = self._sc._jsc
+ self._jvm = self._sc._jvm
+ self._ssql_ctx = self._jvm.SQLContext(self._jsc.sc())
+
+ def sql(self, sqlQuery):
+ return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+
+ def applySchema(self, rdd):
+ if isinstance(rdd, SchemaRDD):
+ raise Exception("Cannot apply schema to %s" % SchemaRDD.__name__)
+ elif type(rdd.first()) is not dict:
+ raise Exception("Only RDDs with dictionaries can be converted to %s" % SchemaRDD.__name__)
+
+ jrdd = self._sc._pythonToJavaMap(rdd._jrdd)
+ srdd = self._ssql_ctx.applySchema(jrdd.rdd())
+ return SchemaRDD(srdd, self)
+
+
def _test():
import atexit
import doctest
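Note: together with the SchemaRDD and Row classes added to rdd.py below, the new SQLContext supports a small end-to-end flow. A minimal usage sketch, assuming an existing SparkContext `sc` and an assembly that bundles the new Scala entry points:

```python
from pyspark.context import SQLContext

sqlCtx = SQLContext(sc)

# applySchema only accepts RDDs of dictionaries; the keys become column names
# and the first element is sampled on the Scala side to infer column types.
people = sc.parallelize([{"name": "alice", "age": 1},
                         {"name": "bob", "age": 2}])
srdd = sqlCtx.applySchema(people)

srdd.registerAsTable("people")
adults = sqlCtx.sql("SELECT name FROM people WHERE age > 1")
for row in adults.toPython().collect():
    print(row["name"])
```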
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 6a16756e0576d..d8dd2a65225e1 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -64,5 +64,6 @@ def run(self):
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fb27863e07f55..e11300212c4b1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1387,6 +1387,42 @@ def _jrdd(self):
def _is_pipelinable(self):
return not (self.is_cached or self.is_checkpointed)
+class Row(dict):
+
+ def __init__(self, d):
+ d.update(self.__dict__)
+ self.__dict__ = d
+ dict.__init__(self, d)
+
+class SchemaRDD(RDD):
+
+ def __init__(self, jschema_rdd, sql_ctx):
+ self.sql_ctx = sql_ctx
+ self._sc = sql_ctx._sc
+ self._jschema_rdd = jschema_rdd
+
+ self.is_cached = False
+ self.is_checkpointed = False
+ self.ctx = self.sql_ctx._sc
+ self._jrdd_deserializer = self.ctx.serializer
+
+ @property
+ def _jrdd(self):
+ return self.toPython()._jrdd
+
+ @property
+ def _id(self):
+ return self._jrdd.id()
+
+ def registerAsTable(self, name):
+ self._jschema_rdd.registerAsTable(name)
+
+ def toPython(self):
+ jrdd = self._jschema_rdd.javaToPython()
+ # TODO: This is inefficient, we should construct the Python Row object
+ # in Java land in the javaToPython function. May require a custom
+ # pickle serializer in Pyrolite
+ return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))
def _test():
import doctest
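Note: the Row helper above is a dict whose contents are also exposed as attributes (its `__dict__` is pointed at the same mapping), so query results can be accessed either way. A small illustration:

```python
from pyspark.rdd import Row

r = Row({"name": "alice", "age": 1})
assert r["name"] == "alice"  # dict-style access
assert r.age == 1            # attribute-style access, backed by the same mapping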
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 36059c6630aa4..ef23cf1739246 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,11 +25,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.execution._
+import org.apache.spark.api.java.JavaRDD
/**
* ALPHA COMPONENT
@@ -238,4 +240,29 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def debugExec() = DebugQuery(executedPlan).execute().collect()
}
+
+ def applySchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+ val schema = rdd.first.map { case (fieldName, obj) =>
+ val dataType = obj.getClass match {
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
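+ // NOTE: only java.lang.String and java.lang.Integer values are handled so far;
+ // any other value class falls through this match with a MatchError. The
+ // commented-out cases below sketch how further boxed types could be added.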
+ // case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ // case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ // case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ // case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ // case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ // case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+ }
+ AttributeReference(fieldName, dataType, true)()
+ }.toSeq
+
+ val rowRdd = rdd.mapPartitions { iter =>
+ iter.map { map =>
+ new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+ }
+ }
+ new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
+ }
+
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index fc95781448569..5dec32ef418df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import net.razorvine.pickle.{Pickler, Unpickler}
+
import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
@@ -24,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.types.BooleanType
+import org.apache.spark.api.java.JavaRDD
+import java.util.{Map => JMap}
/**
* ALPHA COMPONENT
@@ -307,4 +311,20 @@ class SchemaRDD(
/** FOR INTERNAL USE ONLY */
def analyze = sqlContext.analyzer(logicalPlan)
+
+ def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldNames: Seq[String] = this.queryExecution.analyzed.output.map(_.name)
+ this.mapPartitions { iter =>
+ val pickle = new Pickler
+ iter.map { row =>
+ val map: JMap[String, Any] = new java.util.HashMap
+ val arr: java.util.ArrayList[Any] = new java.util.ArrayList
+ row.zip(fieldNames).foreach { case (obj, name) =>
+ map.put(name, obj)
+ }
+ arr.add(map)
+ pickle.dumps(arr)
+ }
+ }
+ }
}
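Note: javaToPython wraps each row's map in a single-element ArrayList before pickling. Assuming PySpark's batched pickle deserializer, which flattens each pickled list back into individual items, this lets toPython() on the Python side recover one dict per row. A rough sketch of the payload shape:

```python
import pickle

# Roughly what each JVM element looks like once Pyrolite has pickled it:
payload = pickle.dumps([{"name": "alice", "age": 1}])

# The Python-side deserializer unpickles the list and yields its elements,
# which toPython() then wraps in Row objects.
rows = pickle.loads(payload)  # a list containing one dict per row
```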
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 573345e42c43c..4ca4505fbfc5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.api.java
import java.beans.{Introspector, PropertyDescriptor}
+import java.util.{Map => JMap}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.SQLContext
@@ -82,7 +83,6 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
}
-
/**
* Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
*/