
Commit f03ef93

Fixes a bug when writing small decimals coming from rows that are not UnsafeRow
1 parent 5b08a20 commit f03ef93

2 files changed: +16 −11 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala

Lines changed: 8 additions & 4 deletions
@@ -197,12 +197,16 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi
     val numBytes = minBytesForPrecision(precision)

     val int32Writer =
-      (row: SpecializedGetters, ordinal: Int) =>
-        recordConsumer.addInteger(row.getLong(ordinal).toInt)
+      (row: SpecializedGetters, ordinal: Int) => {
+        val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong
+        recordConsumer.addInteger(unscaledLong.toInt)
+      }

     val int64Writer =
-      (row: SpecializedGetters, ordinal: Int) =>
-        recordConsumer.addLong(row.getLong(ordinal))
+      (row: SpecializedGetters, ordinal: Int) => {
+        val unscaledLong = row.getDecimal(ordinal, precision, scale).toUnscaledLong
+        recordConsumer.addLong(unscaledLong)
+      }

     val binaryWriterUsingUnscaledLong =
       (row: SpecializedGetters, ordinal: Int) => {
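Context on the root cause (not part of the commit itself): UnsafeRow stores a decimal with precision <= 18 as its unscaled long value, so row.getLong(ordinal) happened to return the right number there, while other InternalRow implementations such as GenericInternalRow keep an actual Decimal object and getLong fails on them. The fix reads the value through getDecimal, which every SpecializedGetters implementation supports, and then takes the unscaled long. A minimal sketch of the difference, assuming Spark's catalyst classes are on the classpath (the ordinal, precision, and scale values are made up for illustration):

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.sql.types.Decimal

    // A non-UnsafeRow row holding 3.14 as DECIMAL(5, 2).
    val row = new GenericInternalRow(Array[Any](Decimal(BigDecimal("3.14"), 5, 2)))

    // Old path: assumes the decimal is physically stored as an unscaled long.
    // That is only true for UnsafeRow; here the slot holds a Decimal object,
    // so this would throw (a ClassCastException in most Spark versions).
    // row.getLong(0)

    // New path: ask for the Decimal and take its unscaled long representation.
    val unscaled = row.getDecimal(0, 5, 2).toUnscaledLong  // 314L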

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala

Lines changed: 8 additions & 7 deletions
@@ -100,13 +100,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
   }

   testStandardAndLegacyModes("fixed-length decimals") {
-    def makeDecimalRDD(decimal: DecimalType): DataFrame =
-      sparkContext
-        .parallelize(0 to 1000)
-        .map(i => Tuple1((i - 500) / 100.0))
-        .toDF()
-        // Parquet doesn't allow column names with spaces, have to add an alias here
-        .select($"_1" cast decimal as "dec")
+    def makeDecimalRDD(decimal: DecimalType): DataFrame = {
+      sqlContext
+        .range(1000)
+        // Parquet doesn't allow column names with spaces, have to add an alias here.
+        // Minus 500 here so that negative decimals are also tested.
+        .select((('id - 500) / 100.0) cast decimal as 'dec)
+        .coalesce(1)
+    }

     val combinations = Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))
     for ((precision, scale) <- combinations) {
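A rough standalone sketch of the round trip the updated test exercises, assuming a SQLContext named sqlContext with its implicits imported and a writable local path (the helper name makeDecimalDF and the path are illustrative, not from the test):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.DecimalType
    import sqlContext.implicits._

    // Mirrors the new makeDecimalRDD: 1000 decimals from -5.00 to 4.99 in a
    // single partition, so negative values are covered and one file is written.
    def makeDecimalDF(decimal: DecimalType): DataFrame =
      sqlContext
        .range(1000)
        .select((('id - 500) / 100.0) cast decimal as 'dec)
        .coalesce(1)

    // Round-trip a few (precision, scale) combinations through Parquet and
    // compare what comes back with what was written.
    for ((precision, scale) <- Seq((5, 2), (18, 10), (38, 37))) {
      val df = makeDecimalDF(DecimalType(precision, scale))
      val path = s"/tmp/parquet-decimals-$precision-$scale"  // illustrative path
      df.write.mode("overwrite").parquet(path)
      val readBack = sqlContext.read.parquet(path)
      assert(readBack.collect().toSet == df.collect().toSet)
    }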
