
Commit 5fe321e

SPARK-4176: Spark should save decimal values with precision > 18 as Parquet FIXED_LEN_BYTE_ARRAY
Parquet defines multiple ways to store decimals. This patch enables reading all of those variations, as well as writing decimals in the smallest fixed-length container possible (INT32, INT64, FIXED_LEN_BYTE_ARRAY).
1 parent 085a721 commit 5fe321e
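
For context on what "smallest container possible" means here: the Parquet format spec lets a DECIMAL annotation sit on INT32 (precision up to 9), INT64 (up to 18), or a FIXED_LEN_BYTE_ARRAY of a computed width. A rough, hedged sketch of that selection in Scala (illustrative only; containerFor and minBytesForPrecision are hypothetical names, not the commit's code):

    object DecimalContainerSketch {
      sealed trait Container
      case object Int32 extends Container                      // precision <= 9
      case object Int64 extends Container                      // precision <= 18
      final case class FixedLenByteArray(numBytes: Int) extends Container

      // Smallest number of two's-complement bytes that can hold `precision` decimal digits.
      def minBytesForPrecision(precision: Int): Int =
        Iterator.from(1).find(n => math.pow(2.0, 8 * n - 1) >= math.pow(10.0, precision)).get

      // Pick the smallest container the Parquet format spec allows for this precision.
      def containerFor(precision: Int): Container =
        if (precision <= 9) Int32
        else if (precision <= 18) Int64
        else FixedLenByteArray(minBytesForPrecision(precision))
    }

The actual change wires this choice into CatalystSchemaConverter and ParquetTypesConverter in the diffs below.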

File tree

5 files changed: +75 −58 lines


sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala

Lines changed: 16 additions & 17 deletions
@@ -169,11 +169,12 @@ private[parquet] class CatalystSchemaConverter(
       }

     case INT96 =>
-      CatalystSchemaConverter.analysisRequire(
-        assumeInt96IsTimestamp,
-        "INT96 is not supported unless it's interpreted as timestamp. " +
-          s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
-      TimestampType
+      field.getOriginalType match {
+        case DECIMAL => makeDecimalType(maxPrecisionForBytes(12))
+        case _ if assumeInt96IsTimestamp => TimestampType
+        case null => makeDecimalType(maxPrecisionForBytes(12))
+        case _ => illegalType()
+      }

     case BINARY =>
       field.getOriginalType match {
@@ -373,8 +374,10 @@

     // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
     // always store decimals in fixed-length byte arrays.
+    // Always storing FIXED_LEN_BYTE_ARRAY is thus compatible with spark <= 1.4.x, except for
+    // precisions > 18.
     case DecimalType.Fixed(precision, scale)
-      if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+      if !followParquetFormatSpec =>
       Types
         .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
         .as(DECIMAL)
@@ -383,30 +386,25 @@
         .length(minBytesForPrecision(precision))
         .named(field.name)

-    case dec @ DecimalType() if !followParquetFormatSpec =>
-      throw new AnalysisException(
-        s"Data type $dec is not supported. " +
-          s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
-          "decimal precision and scale must be specified, " +
-          "and precision must be less than or equal to 18.")
-
     // =====================================
     // Decimals (follow Parquet format spec)
     // =====================================

-    // Uses INT32 for 1 <= precision <= 9
+    // Uses INT32 for 4 byte encodings / precision <= 9
     case DecimalType.Fixed(precision, scale)
-      if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec =>
+      if followParquetFormatSpec && maxPrecisionForBytes(3) < precision &&
+        precision <= maxPrecisionForBytes(4) =>
       Types
         .primitive(INT32, repetition)
         .as(DECIMAL)
         .precision(precision)
         .scale(scale)
         .named(field.name)

-    // Uses INT64 for 1 <= precision <= 18
+    // Uses INT64 for 8 byte encodings / precision <= 18
     case DecimalType.Fixed(precision, scale)
-      if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec =>
+      if followParquetFormatSpec && maxPrecisionForBytes(7) < precision &&
+        precision <= maxPrecisionForBytes(8) =>
       Types
         .primitive(INT64, repetition)
         .as(DECIMAL)
@@ -562,4 +560,5 @@ private[parquet] object CatalystSchemaConverter {
     throw new AnalysisException(message)
   }
 }
+
 }
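
These hunks lean on maxPrecisionForBytes, whose definition is not part of this diff. A plausible definition, consistent with the guards above (9 digits for 4 bytes, 18 digits for 8 bytes), would be:

    // Assumed helper (not shown in this diff): the largest decimal precision that a signed
    // two's-complement value of `numBytes` bytes can always represent.
    def maxPrecisionForBytes(numBytes: Int): Int =
      math.floor(math.log10(math.pow(2, 8 * numBytes - 1) - 1)).toInt

    // Sanity checks matching the case guards above (illustrative):
    // maxPrecisionForBytes(4) == 9, maxPrecisionForBytes(8) == 18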

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 16 additions & 12 deletions
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.parquet

+import java.math.BigInteger
 import java.nio.ByteOrder

 import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap}
@@ -241,26 +242,29 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
   def getCurrentRecord: InternalRow = throw new UnsupportedOperationException

   /**
-   * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
-   * a long (i.e. precision <= 18)
+   * Read a decimal value from a Parquet Binary into "dest".
    *
    * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object.
    */
   protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = {
     val precision = ctype.precisionInfo.get.precision
     val scale = ctype.precisionInfo.get.scale
     val bytes = value.getBytes
-    require(bytes.length <= 16, "Decimal field too large to read")
-    var unscaled = 0L
-    var i = 0
-    while (i < bytes.length) {
-      unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
-      i += 1
+    if (precision <= 18) {
+      var unscaled = 0L
+      var i = 0
+      while (i < bytes.length) {
+        unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
+        i += 1
+      }
+      // Make sure unscaled has the right sign, by sign-extending the first bit
+      val numBits = 8 * bytes.length
+      unscaled = (unscaled << (64 - numBits)) >> (64 - numBits)
+      dest.set(unscaled, precision, scale)
+    } else {
+      val decimal = new java.math.BigDecimal(new BigInteger(bytes), scale)
+      dest.set(new BigDecimal(decimal))
     }
-    // Make sure unscaled has the right sign, by sign-extending the first bit
-    val numBits = 8 * bytes.length
-    unscaled = (unscaled << (64 - numBits)) >> (64 - numBits)
-    dest.set(unscaled, precision, scale)
   }

   /**
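
The precision <= 18 branch above packs the big-endian bytes into a Long and then sign-extends. A self-contained sketch of that decoding step (standalone illustration, not Spark code):

    // Decode a big-endian two's-complement byte array (at most 8 bytes) into a Long.
    def decodeUnscaledLong(bytes: Array[Byte]): Long = {
      require(bytes.nonEmpty && bytes.length <= 8, "at most 8 bytes fit in a Long")
      var unscaled = 0L
      var i = 0
      while (i < bytes.length) {
        unscaled = (unscaled << 8) | (bytes(i) & 0xFF)  // accumulate big-endian bytes
        i += 1
      }
      // Sign-extend from the value's width up to 64 bits.
      val numBits = 8 * bytes.length
      (unscaled << (64 - numBits)) >> (64 - numBits)
    }

    // Example: Array(0xFF.toByte, 0x85.toByte) decodes to -123.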

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 28 additions & 17 deletions
@@ -212,10 +212,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
       case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
       case DateType => writer.addInteger(value.asInstanceOf[Int])
       case d: DecimalType =>
-        if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
+        writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.map(_.precision).getOrElse(10))
       case _ => sys.error(s"Do not know how to writer $schema to consumer")
     }
   }
@@ -297,19 +294,35 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
   }

   // Scratch array used to write decimals as fixed-length binary
-  private[this] val scratchBytes = new Array[Byte](8)
+  private[this] val scratchBytes = new Array[Byte](4096)

   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
     val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
-    val unscaledLong = decimal.toUnscaledLong
-    var i = 0
-    var shift = 8 * (numBytes - 1)
-    while (i < numBytes) {
-      scratchBytes(i) = (unscaledLong >> shift).toByte
-      i += 1
-      shift -= 8
+    if (precision <= 18) {
+      val unscaledLong = decimal.toUnscaledLong
+      var i = 0
+      var shift = 8 * (numBytes - 1)
+      while (i < numBytes) {
+        scratchBytes(i) = (unscaledLong >> shift).toByte
+        i += 1
+        shift -= 8
+      }
+      writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+    } else {
+      val bytes = decimal.toBigDecimal.underlying.unscaledValue.toByteArray()
+      val outBuffer =
+        if (bytes.length == numBytes) {
+          bytes
+        } else {
+          val b = if (numBytes <= scratchBytes.length) scratchBytes else new Array[Byte](numBytes)
+          if (b == scratchBytes && numBytes < scratchBytes.length) {
+            java.util.Arrays.fill(b, 0, numBytes - bytes.length, 0.toByte)
+          }
+          System.arraycopy(bytes, 0, b, numBytes - bytes.length, bytes.length)
+          b
+        }
+      writer.addBinary(Binary.fromByteArray(outBuffer, 0, numBytes))
     }
-    writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
   }

   // array used to write Timestamp as Int96 (fixed-length binary)
@@ -367,10 +380,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
       case DateType => writer.addInteger(record.getInt(index))
       case TimestampType => writeTimestamp(record.getLong(index))
       case d: DecimalType =>
-        if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
+        writeDecimal(record(index).asInstanceOf[Decimal],
+          d.precisionInfo.map(_.precision).getOrElse(10))
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }
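
For precisions above 18, the writer above takes the unscaled value's minimal two's-complement encoding (BigInteger.toByteArray) and left-pads it into a numBytes-wide buffer. A standalone sketch of that padding idea, using the sign byte as the pad value (the commit itself zero-fills a reused scratch buffer instead):

    import java.math.BigInteger

    // Sketch (not the commit's code): widen BigInteger.toByteArray, which is the minimal
    // big-endian two's-complement form, out to a fixed numBytes width.
    def toFixedLength(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
      val minimal = unscaled.toByteArray
      require(minimal.length <= numBytes, s"value needs ${minimal.length} bytes, only $numBytes given")
      val out = new Array[Byte](numBytes)
      // Pad with the sign byte so the widened value keeps the same numeric meaning.
      val pad: Byte = if (unscaled.signum < 0) 0xFF.toByte else 0x00
      java.util.Arrays.fill(out, 0, numBytes - minimal.length, pad)
      System.arraycopy(minimal, 0, out, numBytes - minimal.length, minimal.length)
      out
    }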

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 14 additions & 3 deletions
@@ -43,16 +43,27 @@ private[parquet] object ParquetTypesConverter extends Logging {
   }

   /**
-   * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision.
+   * BYTES_FOR_PRECISION computes the required bytes to store a value of a certain decimal
+   * precision.
    */
-  private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision =>
-    var length = 1
+  private[parquet] def BYTES_FOR_PRECISION_COMPUTE(precision : Int) : Int = {
+    var length = (precision / math.log10(2) - 1).toInt / 8
     while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
       length += 1
     }
     length
   }

+  private[parquet] def BYTES_FOR_PRECISION_STATIC =
+    (0 to 30).map(BYTES_FOR_PRECISION_COMPUTE).toArray
+
+  private[parquet] def BYTES_FOR_PRECISION(precision : Int) : Int =
+    if (precision < BYTES_FOR_PRECISION_STATIC.length) {
+      BYTES_FOR_PRECISION_STATIC(precision)
+    } else {
+      BYTES_FOR_PRECISION_COMPUTE(precision)
+    }
+
   def convertToAttributes(
       parquetSchema: MessageType,
       isBinaryAsString: Boolean,
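
As a rough check on what BYTES_FOR_PRECISION produces, here is a standalone re-implementation of the same width computation with a few sample precisions (illustrative only; the expected widths follow from the formula: 4 bytes for precision 9, 8 for 18, 16 for 38):

    // Standalone copy of the byte-width loop above, starting from length = 1 for simplicity.
    def bytesForPrecision(precision: Int): Int = {
      var length = 1
      while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) length += 1
      length
    }

    // Expected: 1 -> 1, 9 -> 4 (fits INT32), 18 -> 8 (fits INT64), 19 -> 9, 38 -> 16 bytes.
    Seq(1, 9, 18, 19, 38).foreach(p => println(s"precision $p -> ${bytesForPrecision(p)} bytes"))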

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala

Lines changed: 1 addition & 9 deletions
@@ -107,22 +107,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
       // Parquet doesn't allow column names with spaces, have to add an alias here
       .select($"_1" cast decimal as "dec")

-    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (60, 5))) {
       withTempPath { dir =>
         val data = makeDecimalRDD(DecimalType(precision, scale))
         data.write.parquet(dir.getCanonicalPath)
         checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
       }
     }

-    // Decimals with precision above 18 are not yet supported
-    intercept[Throwable] {
-      withTempPath { dir =>
-        makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
-        sqlContext.read.parquet(dir.getCanonicalPath).collect()
-      }
-    }
-
     // Unlimited-length decimals are not yet supported
     intercept[Throwable] {
       withTempPath { dir =>
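
With the write path above in place, a high-precision decimal should round-trip through the DataFrame API roughly as follows (a hedged usage sketch against the Spark 1.x-era API used in this suite; the path and the value are made up):

    import org.apache.spark.sql.types.DecimalType
    import sqlContext.implicits._

    // 34-digit value: needs precision > 18, so it is written as FIXED_LEN_BYTE_ARRAY.
    val df = sqlContext.sparkContext
      .parallelize(Seq(BigDecimal("12345678901234567890123456789.12345")))
      .map(Tuple1.apply)
      .toDF("dec")
      .select($"dec".cast(DecimalType(38, 5)) as "dec")

    df.write.parquet("/tmp/high_precision_decimals")
    sqlContext.read.parquet("/tmp/high_precision_decimals").show()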
