diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d905f8f9858e8..8ca3d356f3bdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -106,12 +106,12 @@ object RowEncoder { returnNullable = false) case d: DecimalType => - StaticInvoke( + CheckOverflow(StaticInvoke( Decimal.getClass, d, "fromDecimal", inputObject :: Nil, - returnNullable = false) + returnNullable = false), d) case StringType => StaticInvoke( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0f900833d2cfe..525c7cef39563 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1647,6 +1647,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, data: _*) checkAnswer(ds.select("x"), Seq(Row(1), Row(2))) } + + test("SPARK-26233: serializer should enforce decimal precision and scale") { + val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8)))) + val encoder = RowEncoder(s) + implicit val uEnc = encoder + val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111))) + checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), + Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) + } } case class TestDataUnion(x: Int, y: Int, z: Int)