Closed
@@ -106,12 +106,12 @@ object RowEncoder {
         returnNullable = false)

     case d: DecimalType =>
-      StaticInvoke(
+      CheckOverflow(StaticInvoke(
         Decimal.getClass,
         d,
         "fromDecimal",
         inputObject :: Nil,
-        returnNullable = false)
+        returnNullable = false), d)

     case StringType =>
       StaticInvoke(
@@ -1647,6 +1647,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(ds, data: _*)
     checkAnswer(ds.select("x"), Seq(Row(1), Row(2)))
   }
+
+  test("SPARK-26233: serializer should enforce decimal precision and scale") {
Member

Can we have a test case in RowEncoderSuite, too?

Contributor Author

It is possible, but not easy. The issue occurs in codegen, not when we retrieve the output, so a plain encode/decode round trip works fine. The problem appears only when a transformation runs in codegen in between, because that code uses the underlying decimal directly and assumes it has the same precision and scale as the data type, which is not always true without this change. I tried checking the precision and scale of the serialized object, but that is not really feasible because they are converted when the value is read (please see UnsafeRow). So I'd rather avoid this.
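
To illustrate the mismatch described above, here is a minimal plain-Scala sketch that uses only `java.math.BigDecimal`, with no Spark dependency. It is not Spark's implementation of `CheckOverflow`: `enforce` is a hypothetical helper, and returning `None` when the value does not fit is an assumption of the sketch rather than something this diff shows. The point is only that a value built from `0.1111` carries its own scale, which differs from the scale declared by a type such as `DecimalType(38, 8)` until it is rescaled explicitly.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object DecimalPrecisionSketch {
  // Hypothetical helper, not a Spark API: rescale `value` to the declared
  // scale, then check that the result fits within the declared precision.
  // Returning None on overflow is an assumption made for this sketch.
  def enforce(value: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
    val rescaled = value.setScale(scale, RoundingMode.HALF_UP)
    if (rescaled.precision() <= precision) Some(rescaled) else None
  }

  def main(args: Array[String]): Unit = {
    // The raw value keeps the scale implied by its own digits.
    val raw = JBigDecimal.valueOf(0.1111)
    println(s"raw: precision=${raw.precision()}, scale=${raw.scale()}")

    // After enforcement the scale matches the declared one (8 here).
    val fixed = enforce(raw, 38, 8)
    println(fixed.map(d => s"enforced: $d, scale=${d.scale()}"))

    // A value with too many digits for the declared precision is rejected.
    println(enforce(new JBigDecimal("12345.678"), 5, 2))
  }
}
```

Any code that reads the raw value while assuming the declared scale would see inconsistent data, which is the situation the added test exercises through `groupBy` and `first`.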

+    val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8))))
+    val encoder = RowEncoder(s)
+    implicit val uEnc = encoder
+    val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111)))
+    checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
+      Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)