Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -515,10 +516,15 @@ case class JsonToStructs(
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

def this(schema: DataType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)
val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)

// The JSON input data might be missing certain fields. We force the nullability
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
// can generate incorrect files if values are missing in columns declared as non-nullable.
val nullableSchema = if (forceNullableSchema) schema.asNullable else schema

override def nullable: Boolean = true

// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression) =
Expand All @@ -535,22 +541,22 @@ case class JsonToStructs(
child = child,
timeZoneId = None)

override def checkInputDataTypes(): TypeCheckResult = schema match {
override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
case _: StructType | ArrayType(_: StructType, _) =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
}

@transient
lazy val rowSchema = schema match {
lazy val rowSchema = nullableSchema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
}

// This converts parsed rows to the desired output by the given schema.
@transient
lazy val converter = schema match {
lazy val converter = nullableSchema match {
case _: StructType =>
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
case ArrayType(_: StructType, _) =>
Expand All @@ -563,7 +569,7 @@ case class JsonToStructs(
rowSchema,
new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get))

override def dataType: DataType = schema
override def dataType: DataType = nullableSchema

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,14 @@ object SQLConf {
.stringConf
.createWithDefault("_corrupt_record")

val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema")
.internal()
.doc("When true, force the output schema of the from_json() function to be nullable " +
"(including all the fields). Otherwise, the schema might not be compatible with" +
"actual data, which leads to curruptions.")
.booleanConf
.createWithDefault(true)

val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout")
.doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
.timeConf(TimeUnit.SECONDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase {
val json =
"""
|{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],
Expand Down Expand Up @@ -680,4 +682,31 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}
}

test("from_json missing fields") {
for (forceJsonNullableSchema <- Seq(false, true)) {
withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) {
val input =
"""{
| "a": 1,
| "c": "foo"
|}
|"""
.stripMargin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the style.

val jsonSchema = new StructType()
.add("a", LongType, nullable = false)
.add("b", StringType, nullable = false)
.add("c", StringType, nullable = false)
val output = InternalRow(1L, null, UTF8String.fromString("foo"))
checkEvaluation(
JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId),
output
)
val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
.dataType
val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
assert(schemaToCompare == schema);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: ; is useless.

}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -780,6 +781,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
assert(option.compressionCodecClassName == "UNCOMPRESSED")
}
}

test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") {
withTempPath { file =>
val jsonData =
"""{
| "a": 1,
| "c": "foo"
|}
|"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same here.

.stripMargin
val jsonSchema = new StructType()
.add("a", LongType, nullable = false)
.add("b", StringType, nullable = false)
.add("c", StringType, nullable = false)
spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input")
.write.parquet(file.getAbsolutePath)
checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo"))))
}
}
}

class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
Expand Down