Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ object JavaTypeInference {
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
keyNullable = true,
ObjectType(valueType.getRawType),
serializerFor(_, valueType),
valueNullable = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ object ScalaReflection extends ScalaReflection {
inputObject,
dataTypeFor(keyType),
serializerFor(_, keyType, keyPath, seenTypeSet),
keyNullable = !keyType.typeSymbol.asClass.isPrimitive,
dataTypeFor(valueType),
serializerFor(_, valueType, valuePath, seenTypeSet),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -841,18 +841,21 @@ object ExternalMapToCatalyst {
inputMap: Expression,
keyType: DataType,
keyConverter: Expression => Expression,
keyNullable: Boolean,
valueType: DataType,
valueConverter: Expression => Expression,
valueNullable: Boolean): ExternalMapToCatalyst = {
val id = curId.getAndIncrement()
val keyName = "ExternalMapToCatalyst_key" + id
val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id
val valueName = "ExternalMapToCatalyst_value" + id
val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id

ExternalMapToCatalyst(
keyName,
keyIsNull,
keyType,
keyConverter(LambdaVariable(keyName, "false", keyType, false)),
keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)),
valueName,
valueIsNull,
valueType,
Expand All @@ -868,6 +871,8 @@ object ExternalMapToCatalyst {
*
* @param key the name of the map key variable that used when iterate the map, and used as input for
* the `keyConverter`
* @param keyIsNull the nullability of the map key variable that used when iterate the map, and
* used as input for the `keyConverter`
* @param keyType the data type of the map key variable that used when iterate the map, and used as
* input for the `keyConverter`
* @param keyConverter A function that take the `key` as input, and converts it to catalyst format.
Expand All @@ -883,6 +888,7 @@ object ExternalMapToCatalyst {
*/
case class ExternalMapToCatalyst private(
key: String,
keyIsNull: String,
keyType: DataType,
keyConverter: Expression,
value: String,
Expand Down Expand Up @@ -913,6 +919,7 @@ case class ExternalMapToCatalyst private(

val keyElementJavaType = ctx.javaType(keyType)
val valueElementJavaType = ctx.javaType(valueType)
ctx.addMutableState("boolean", keyIsNull, "")
ctx.addMutableState(keyElementJavaType, key, "")
ctx.addMutableState("boolean", valueIsNull, "")
ctx.addMutableState(valueElementJavaType, value, "")
Expand Down Expand Up @@ -950,6 +957,12 @@ case class ExternalMapToCatalyst private(
defineEntries -> defineKeyValue
}

val keyNullCheck = if (ctx.isPrimitiveType(keyType)) {
s"$keyIsNull = false;"
} else {
s"$keyIsNull = $key == null;"
}

val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
s"$valueIsNull = false;"
} else {
Expand All @@ -972,6 +985,7 @@ case class ExternalMapToCatalyst private(
$defineEntries
while($entries.hasNext()) {
$defineKeyValue
$keyNullCheck
$valueNullCheck

${genKeyConverter.code}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
checkNullable[String](true)
}

test("null check for map key") {
test("null check for map key: String") {
val encoder = ExpressionEncoder[Map[String, Int]]()
val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
assert(e.getMessage.contains("Cannot use null as map key"))
}

test("null check for map key: Integer") {
val encoder = ExpressionEncoder[Map[Integer, String]]()
val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b"))))
assert(e.getMessage.contains("Cannot use null as map key"))
}

private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
Expand Down