diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 672c344a56597..29a0f7889659c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -245,7 +245,7 @@ private object SpecialLengths {
   val TIMING_DATA = -3
 }
 
-private[spark] object PythonRDD {
+private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -301,15 +301,23 @@ private[spark] object PythonRDD {
             throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
         }
       case other =>
-        throw new SparkException("Unexpected element type " + first.getClass)
+        if (other == null) {
+          logDebug("Encountered NULL element from iterator. We skip writing NULL to stream.")
+        } else {
+          throw new SparkException("Unexpected element type " + first.getClass)
+        }
     }
   }
 
 
   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes(UTF8)
-    dataOut.writeInt(bytes.length)
-    dataOut.write(bytes)
+    if (str == null) {
+      logDebug("Encountered NULL string. We skip writing NULL to stream.")
+    } else {
+      val bytes = str.getBytes(UTF8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
+    }
   }
 
   def writeToFile[T](items: java.util.Iterator[T], filename: String) {
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 7b866f08a0e9f..c62f341441e53 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -23,11 +23,16 @@ import org.scalatest.FunSuite
 
 class PythonRDDSuite extends FunSuite {
 
-  test("Writing large strings to the worker") {
-    val input: List[String] = List("a"*100000)
-    val buffer = new DataOutputStream(new ByteArrayOutputStream)
-    PythonRDD.writeIteratorToStream(input.iterator, buffer)
-  }
+  test("Writing large strings to the worker") {
+    val input: List[String] = List("a"*100000)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
+  test("Handle nulls gracefully") {
+    val input: List[String] = List("a", null)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }
 
 }
 