diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 3b8dde1823370..4e3583a94474e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -38,7 +38,11 @@ private[sql] object JsonRDD extends Logging {
       json: RDD[String],
       schema: StructType,
       columnNameOfCorruptRecords: String): RDD[Row] = {
-    parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema))
+    // Reuse the mutable row for each record and all inner nested structures
+    parseJson(json, columnNameOfCorruptRecords).mapPartitions {
+      val row = new GenericMutableRow(schema.fields.length)
+      iter => iter.map(parsed => asRow(parsed, schema, row))
+    }
   }
 
   private[sql] def inferSchema(
@@ -401,7 +405,8 @@ private[sql] object JsonRDD extends Logging {
     }
   }
 
-  private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
+  private[json] def enforceCorrectType(
+      value: Any, desiredType: DataType, slot: Any = null): Any = {
     if (value == null) {
       null
     } else {
@@ -414,22 +419,41 @@ private[sql] object JsonRDD extends Logging {
         case DecimalType() => toDecimal(value)
         case BooleanType => value.asInstanceOf[BooleanType.JvmType]
         case NullType => null
-        case ArrayType(elementType, _) =>
-          value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
-        case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
+        case ArrayType(elementType, _) => {
+          val arrayLength = value.asInstanceOf[Seq[Any]].length
+          val arraySlot = if (slot != null && slot.asInstanceOf[Seq[Any]].size == arrayLength) {
+            slot.asInstanceOf[Seq[Any]]
+          } else {
+            (new Array[Any](arrayLength)).toSeq
+          }
+          value.asInstanceOf[Seq[Any]].zip(arraySlot).map {
+            case (v, s) => enforceCorrectType(v, elementType, s)
+          }.toList
+        }
+        case struct: StructType =>
+          asRow(value.asInstanceOf[Map[String, Any]], struct, slot.asInstanceOf[GenericMutableRow])
         case DateType => toDate(value)
         case TimestampType => toTimestamp(value)
       }
     }
   }
 
-  private def asRow(json: Map[String,Any], schema: StructType): Row = {
-    // TODO: Reuse the row instead of creating a new one for every record.
-    val row = new GenericMutableRow(schema.fields.length)
-    schema.fields.zipWithIndex.foreach {
-      case (StructField(name, dataType, _, _), i) =>
-        row.update(i, json.get(name).flatMap(v => Option(v)).map(
-          enforceCorrectType(_, dataType)).orNull)
+  private def asRow(
+      json: Map[String, Any],
+      schema: StructType,
+      mutable: GenericMutableRow = null): Row = {
+
+    val row = if (mutable != null && mutable.length == schema.fields.length) {
+      mutable
+    } else {
+      new GenericMutableRow(schema.fields.length)
+    }
+
+    for (i <- 0 until schema.fields.length) {
+      val fieldName = schema.fields(i).name
+      val fieldType = schema.fields(i).dataType
+      row.update(i, json.get(fieldName).flatMap(v => Option(v)).map(
+        enforceCorrectType(_, fieldType, row(i))).orNull)
     }
     row
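
For reference, the core of this change is the row-reuse pattern: allocate one GenericMutableRow per partition (inside mapPartitions, before the per-record iterator runs) and mutate it in place for every record, instead of allocating a fresh row per record. Below is a minimal, self-contained sketch of that pattern with no Spark dependency; MutableRow, toRows, and RowReuseSketch are illustrative stand-ins invented here, not names from the patch.

// Minimal sketch of per-partition row reuse. MutableRow stands in for
// Spark's GenericMutableRow; everything here is illustrative only.
final class MutableRow(val length: Int) {
  private val values = new Array[Any](length)
  def update(i: Int, v: Any): Unit = values(i) = v
  def apply(i: Int): Any = values(i)
  override def toString: String = values.mkString("[", ", ", "]")
}

object RowReuseSketch {
  // One row is allocated per iterator (think: per partition) and mutated
  // in place for each record, mirroring the mapPartitions change above.
  def toRows(
      records: Iterator[Map[String, Any]],
      fieldNames: Seq[String]): Iterator[MutableRow] = {
    val row = new MutableRow(fieldNames.length)
    records.map { rec =>
      var i = 0
      while (i < fieldNames.length) {
        // Missing fields become null, loosely analogous to the .orNull above
        row.update(i, rec.getOrElse(fieldNames(i), null))
        i += 1
      }
      row // the caller must consume (or copy) the row before advancing
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Iterator[Map[String, Any]](
      Map("a" -> 1, "b" -> "x"),
      Map("a" -> 2))
    // Prints [1, x] then [2, null]
    toRows(data, Seq("a", "b")).foreach(println)
  }
}

The trade-off is the usual one for mutable-row reuse: per-record allocation and GC pressure go away, but any consumer that holds a reference across records (for example, collecting rows into a buffer) must copy the row first, since the next record overwrites it in place.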