@@ -406,19 +406,21 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
if (row.numFields > 0) {
val st = fields.map(_.dataType)
val toUTF8StringFuncs = st.map(castToString)
if (row.isNullAt(0)) {
if (fields(0).nullable && row.isNullAt(0)) {
Contributor

if fields(0).nullable is false, how can row.isNullAt(0) be true?

Member

(I have the same question)

Contributor Author (@cfmcgrady, Jul 1, 2021)

If a user creates a DataFrame via spark.internalCreateDataFrame(), row.isNullAt() may return true even though the schema's nullable flag is false.
For instance:

  val schema = StructType(Seq(
    StructField("x",
      StructType(Seq(
        StructField("y", IntegerType, true),
        StructField("z", IntegerType, false)
      )))))
  val rdd = spark.sparkContext.parallelize(Seq(InternalRow(InternalRow(1, null))))
  val df = spark.internalCreateDataFrame(rdd, schema)
  df.show
  // current master branch output
  //  +---------+
  //  |        x|
  //  +---------+
  //  |{1, null}|
  //  +---------+

Although spark.internalCreateDataFrame() is a sql-package-private API, spark.read.json() and spark.read.csv() call it without handling null values (see the example in the PR description).

Contributor


Then we need to fix the nullability. There are many places in the Spark codebase that rely on nullability for optimizations; it's not possible to change all of them to stop trusting nullability.

Can we fix spark.read.json() to set the nullability correctly?

Contributor Author


Ok, let me try.
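
One possible direction, sketched here as an assumption rather than the approach the PR finally took: recursively relax the schema that spark.read.json() / spark.read.csv() hand to internalCreateDataFrame() so that every field is nullable. The helper name forceNullable is hypothetical.

import org.apache.spark.sql.types._

// Hypothetical helper, not from this PR: make every nested field nullable so the
// rows produced by the JSON/CSV readers can never contradict the schema.
def forceNullable(dt: DataType): DataType = dt match {
  case StructType(fields) =>
    StructType(fields.map(f => f.copy(dataType = forceNullable(f.dataType), nullable = true)))
  case ArrayType(elementType, _) =>
    ArrayType(forceNullable(elementType), containsNull = true)
  case MapType(keyType, valueType, _) =>
    MapType(forceNullable(keyType), forceNullable(valueType), valueContainsNull = true)
  case other => other
}

// e.g. relax a user-supplied read schema before building the DataFrame:
// val relaxed = forceNullable(schema).asInstanceOf[StructType]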

if (!legacyCastToStr) builder.append("null")
} else {
builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String])
val accessor = InternalRow.getAccessor(fields(0).dataType, fields(0).nullable)
builder.append(toUTF8StringFuncs(0)(accessor(row, 0)).asInstanceOf[UTF8String])
}
var i = 1
while (i < row.numFields) {
builder.append(",")
if (row.isNullAt(i)) {
if (fields(i).nullable && row.isNullAt(i)) {
if (!legacyCastToStr) builder.append(" null")
} else {
builder.append(" ")
builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String])
val accessor = InternalRow.getAccessor(fields(i).dataType, fields(i).nullable)
builder.append(toUTF8StringFuncs(i)(accessor(row, i)).asInstanceOf[UTF8String])
}
i += 1
}
@@ -868,8 +870,13 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val newRow = new GenericInternalRow(from.fields.length)
var i = 0
while (i < row.numFields) {
newRow.update(i,
if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType)))
val value = if (from.fields(i).nullable && row.isNullAt(i)) {
null
} else {
val accessor = InternalRow.getAccessor(from.fields(i).dataType, from.fields(i).nullable)
castFuncs(i)(accessor(row, i))
}
newRow.update(i, value)
i += 1
}
newRow
@@ -1098,29 +1105,37 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
}

private def writeStructToStringBuilder(
st: Seq[DataType],
st: Seq[StructField],
row: ExprValue,
buffer: ExprValue,
ctx: CodegenContext): Block = {
val structToStringCode = st.zipWithIndex.map { case (ft, i) =>
val fieldToStringCode = castToStringCode(ft, ctx)
val field = ctx.freshVariable("field", ft)
val fieldStr = ctx.freshVariable("fieldStr", StringType)
val javaType = JavaCode.javaType(ft)
code"""
|${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock}
|if ($row.isNullAt($i)) {
Contributor Author


When the actual value is null, for a primitive-type field row.isNullAt(i) returns true, but row.getXXX returns a default value.

For example:

val r = new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(Array(1, null))
println(r.getInt(0))   // 1
println(r.getInt(1))   // 0
println(r.isNullAt(1)) // true

So we can't only check row.isNullAt(i) here; we need the same logic as BoundReference.doGenCode() and add a nullable check.
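
The interpreted path in this change leans on the same distinction through InternalRow.getAccessor; a minimal standalone sketch (not part of the PR) of what the two accessor flavours return for a null primitive slot:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.IntegerType

val r = new GenericInternalRow(Array[Any](1, null))

// With nullable = true the accessor checks isNullAt first and surfaces the null.
val nullableAccessor = InternalRow.getAccessor(IntegerType, true)
println(nullableAccessor(r, 1)) // null

// With nullable = false the accessor trusts the schema and reads the primitive
// slot directly, so it yields the default value 0 instead of null.
val strictAccessor = InternalRow.getAccessor(IntegerType, false)
println(strictAccessor(r, 1)) // 0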

| ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "null" else " null")}
|} else {
| ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock}
|
| // Append $i field into the string buffer
| $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
|}
""".stripMargin
val structToStringCode = st.zipWithIndex.map {
case (StructField(_, dataType, nullable, _), i) =>
val fieldToStringCode = castToStringCode(dataType, ctx)
val field = ctx.freshVariable("field", dataType)
val fieldStr = ctx.freshVariable("fieldStr", StringType)
val javaType = JavaCode.javaType(dataType)

val isNull = if (nullable) {
code"$row.isNullAt($i)"
} else {
code"false"
}

code"""
|${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock}
|if ($isNull) {
| ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "null" else " null")}
|} else {
| ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock}
|
| // Append $i field into the string buffer
| $javaType $field = ${CodeGenerator.getValue(row, dataType, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
|}
""".stripMargin
}

val writeStructCode = ctx.splitExpressions(
@@ -1184,7 +1199,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val row = ctx.freshVariable("row", classOf[InternalRow])
val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder])
val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx)
val writeStructCode = writeStructToStringBuilder(fields, row, buffer, ctx)
code"""
|InternalRow $row = $c;
|$bufferClass $buffer = new $bufferClass();
@@ -1890,8 +1905,15 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
val toFieldNull = ctx.freshVariable("tfn", BooleanType)
val fromType = JavaCode.javaType(from.fields(i).dataType)
val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)

val isNull = if (from.fields(i).nullable) {
code"boolean $fromFieldNull = $tmpInput.isNullAt($i);"
} else {
code"boolean $fromFieldNull = false;"
}

code"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
$isNull
if ($fromFieldNull) {
$tmpResult.setNullAt($i);
} else {
@@ -1271,4 +1271,49 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
}

test("SPARK-35912: Cast struct contains the null value to string") {
Seq(true, false).foreach { nullable =>
val lit = Literal.create(InternalRow(InternalRow(1, null)),
StructType(Seq(StructField("c1",
StructType(Seq(
StructField("c2", IntegerType, true),
StructField("c3", IntegerType, nullable)
))
)))
)
val ret = cast(lit, StringType)
assert(ret.resolved)
val expected = if (nullable) {
"{{1, null}}"
} else {
"{{1, 0}}"
}
checkEvaluation(ret, expected)
}
}

test("SPARK-35912: Cast struct contains the null value to struct") {
Seq(true, false).foreach { nullable =>
val lit = Literal.create(InternalRow(1, null),
StructType(Seq(
StructField("c1", IntegerType, true),
StructField("c2", IntegerType, nullable)
))
)
val toType = StructType(Seq(
StructField("c1", StringType, true),
StructField("c2", StringType, true)
))

val expected = if (nullable) {
InternalRow(UTF8String.fromString("1"), null)
} else {
InternalRow(UTF8String.fromString("1"), UTF8String.fromString("0"))
}
val ret = cast(lit, toType)
assert(ret.resolved)
checkEvaluation(ret, expected)
}
}
}