diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d8453fb184a3..9547080a567b9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,6 +22,7 @@ import java.net._
 import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
+import scala.annotation.tailrec
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
 import scala.util.Try
@@ -270,9 +271,10 @@ private object SpecialLengths {
   val END_OF_DATA_SECTION = -1
   val PYTHON_EXCEPTION_THROWN = -2
   val TIMING_DATA = -3
+  val NULL = -4
 }
 
-private[spark] object PythonRDD extends Logging {
+private[spark] object PythonRDD {
   val UTF8 = Charset.forName("UTF-8")
 
   /**
@@ -312,42 +314,51 @@ private[spark] object PythonRDD extends Logging {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
+  @tailrec
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
     // The right way to implement this would be to use TypeTags to get the full
     // type of T. Since I don't want to introduce breaking changes throughout the
     // entire Spark API, I have to use this hacky approach:
+    def writeBytes(bytes: Array[Byte]) {
+      if (bytes == null) {
+        dataOut.writeInt(SpecialLengths.NULL)
+      } else {
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      }
+    }
     if (iter.hasNext) {
       val first = iter.next()
       val newIter = Seq(first).iterator ++ iter
       first match {
         case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
-            dataOut.writeInt(bytes.length)
-            dataOut.write(bytes)
-          }
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { writeBytes(_) }
         case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach { str =>
-            writeUTF(str, dataOut)
-          }
+          newIter.asInstanceOf[Iterator[String]].foreach { writeUTF(_, dataOut) }
         case pair: Tuple2[_, _] =>
           pair._1 match {
             case bytePair: Array[Byte] =>
-              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
-                dataOut.writeInt(pair._1.length)
-                dataOut.write(pair._1)
-                dataOut.writeInt(pair._2.length)
-                dataOut.write(pair._2)
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach {
+                case (k, v) =>
+                  writeBytes(k)
+                  writeBytes(v)
               }
             case stringPair: String =>
-              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
-                writeUTF(pair._1, dataOut)
-                writeUTF(pair._2, dataOut)
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach {
+                case (k, v) =>
+                  writeUTF(k, dataOut)
+                  writeUTF(v, dataOut)
               }
             case other =>
               throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
           }
         case other =>
-          throw new SparkException("Unexpected element type " + first.getClass)
+          if (other == null) {
+            dataOut.writeInt(SpecialLengths.NULL)
+            writeIteratorToStream(iter, dataOut)
+          } else {
+            throw new SparkException("Unexpected element type " + first.getClass)
+          }
       }
     }
   }
@@ -527,9 +538,13 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes(UTF8)
-    dataOut.writeInt(bytes.length)
-    dataOut.write(bytes)
+    if (str == null) {
+      dataOut.writeInt(SpecialLengths.NULL)
+    } else {
+      val bytes = str.getBytes(UTF8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
+    }
   }
 
   def writeToFile[T](items: java.util.Iterator[T], filename: String) {
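The Scala-side change above establishes a single framing rule for everything written to the Python worker: each value is length-prefixed, and a null value is encoded as the sentinel length SpecialLengths.NULL (-4) with no payload bytes. The sketch below is illustrative only and not part of the patch; it reproduces the writeUTF framing in a standalone object (all names except the -4 sentinel value are local to this example) so the resulting byte layout can be inspected.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.Charset

// Standalone sketch: how the null framing introduced in PythonRDD.writeUTF behaves.
object NullFramingSketch {
  val UTF8 = Charset.forName("UTF-8")
  val NULL = -4 // mirrors SpecialLengths.NULL

  // Null values are encoded as the sentinel length -4 with no payload;
  // everything else is written as (byte count, UTF-8 bytes).
  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
    if (str == null) {
      dataOut.writeInt(NULL)
    } else {
      val bytes = str.getBytes(UTF8)
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
    }
  }

  def main(args: Array[String]): Unit = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    writeUTF("ab", out)  // frames as: 00 00 00 02 61 62
    writeUTF(null, out)  // frames as: ff ff ff fc (the -4 sentinel, no payload)
    out.flush()
    println(bos.toByteArray.map(b => "%02x".format(b & 0xff)).mkString(" "))
  }
}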
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 7b866f08a0e9f..d345115e327bd 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -23,11 +23,21 @@ import org.scalatest.FunSuite
 
 class PythonRDDSuite extends FunSuite {
 
-    test("Writing large strings to the worker") {
-        val input: List[String] = List("a"*100000)
-        val buffer = new DataOutputStream(new ByteArrayOutputStream)
-        PythonRDD.writeIteratorToStream(input.iterator, buffer)
-    }
+  test("Writing large strings to the worker") {
+    val input: List[String] = List("a" * 100000)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
 
-}
+  test("Handle nulls gracefully") {
+    val input: List[String] = List("a", null)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
 
+  test("Handle list starts with nulls gracefully") {
+    val input: List[String] = List(null, null, "a", null)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
+}
\ No newline at end of file
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 03b31ae9624c2..574ef926cec3d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -75,6 +75,7 @@ class SpecialLengths(object):
     END_OF_DATA_SECTION = -1
     PYTHON_EXCEPTION_THROWN = -2
     TIMING_DATA = -3
+    NULL = -4
 
 
 class Serializer(object):
@@ -336,6 +337,8 @@ class UTF8Deserializer(Serializer):
 
     def loads(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.NULL:
+            return None
         return stream.read(length).decode('utf8')
 
     def load_stream(self, stream):
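On the Python side, UTF8Deserializer.loads now checks the length prefix against SpecialLengths.NULL and returns None instead of trying to read -4 bytes, so both ends of the worker protocol agree on the sentinel. The sketch below is illustrative only and not part of the patch: it renders that read logic in Scala (the object and method names are local to this example) to show how a framed stream containing a null round-trips.

import java.io.{ByteArrayInputStream, DataInputStream}
import java.nio.charset.Charset

// Standalone sketch: the read side of the framing, mirroring what
// UTF8Deserializer.loads does in python/pyspark/serializers.py after this change.
object NullReadingSketch {
  val UTF8 = Charset.forName("UTF-8")
  val NULL = -4 // mirrors SpecialLengths.NULL on both sides of the protocol

  // A length of -4 means "null"; a non-negative length is followed by that
  // many UTF-8 payload bytes.
  def readUTF(in: DataInputStream): String = {
    val length = in.readInt()
    if (length == NULL) {
      null
    } else {
      val bytes = new Array[Byte](length)
      in.readFully(bytes)
      new String(bytes, UTF8)
    }
  }

  def main(args: Array[String]): Unit = {
    // Bytes a writer would produce for "ab" followed by a null.
    val framed = Array[Byte](0, 0, 0, 2, 'a'.toByte, 'b'.toByte, -1, -1, -1, -4)
    val in = new DataInputStream(new ByteArrayInputStream(framed))
    println(readUTF(in))         // prints: ab
    println(readUTF(in) == null) // prints: true
  }
}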