Skip to content

Commit 8cac615

Browse files
authored
[FLINK-22586][table] Improve the precision derivation for decimal arithmetic
This closes #15848
1 parent 15e870d commit 8cac615

File tree

10 files changed

+216
-88
lines changed

10 files changed

+216
-88
lines changed

flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ public final class LogicalTypeMerging {
113113
YEAR_MONTH_RES_TO_BOUNDARIES = new HashMap<>();
114114
private static final Map<List<YearMonthResolution>, YearMonthResolution>
115115
YEAR_MONTH_BOUNDARIES_TO_RES = new HashMap<>();
116+
private static final int MINIMUM_ADJUSTED_SCALE = 6;
116117

117118
static {
118119
addYearMonthMapping(YEAR, YEAR);
@@ -198,50 +199,50 @@ public static Optional<LogicalType> findCommonType(List<LogicalType> types) {
198199
return Optional.empty();
199200
}
200201

202+
// ========================= Decimal Precision Deriving ==========================
203+
// Adopted from "https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-
204+
// scale-and-length-transact-sql"
205+
//
206+
// Operation Result Precision Result Scale
207+
// e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
208+
// e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
209+
// e1 * e2 p1 + p2 + 1 s1 + s2
210+
// e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
211+
// e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
212+
//
213+
// Also, if the precision / scale is out of range, the scale may be sacrificed
214+
// in order to prevent the truncation of the integer part of the decimals.
215+
201216
/** Finds the result type of a decimal division operation. */
202217
public static DecimalType findDivisionDecimalType(
203218
int precision1, int scale1, int precision2, int scale2) {
204-
// adopted from
205-
// https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
206219
int scale = Math.max(6, scale1 + precision2 + 1);
207220
int precision = precision1 - scale1 + scale2 + scale;
208-
if (precision > DecimalType.MAX_PRECISION) {
209-
scale = Math.max(6, DecimalType.MAX_PRECISION - (precision - scale));
210-
precision = DecimalType.MAX_PRECISION;
211-
}
212-
return new DecimalType(false, precision, scale);
221+
return adjustPrecisionScale(precision, scale);
213222
}
214223

215224
/** Finds the result type of a decimal modulo operation. */
216225
public static DecimalType findModuloDecimalType(
217226
int precision1, int scale1, int precision2, int scale2) {
218-
// adopted from Calcite
219227
final int scale = Math.max(scale1, scale2);
220-
int precision =
221-
Math.min(precision1 - scale1, precision2 - scale2) + Math.max(scale1, scale2);
222-
precision = Math.min(precision, DecimalType.MAX_PRECISION);
223-
return new DecimalType(false, precision, scale);
228+
int precision = Math.min(precision1 - scale1, precision2 - scale2) + scale;
229+
return adjustPrecisionScale(precision, scale);
224230
}
225231

226232
/** Finds the result type of a decimal multiplication operation. */
227233
public static DecimalType findMultiplicationDecimalType(
228234
int precision1, int scale1, int precision2, int scale2) {
229-
// adopted from Calcite
230235
int scale = scale1 + scale2;
231-
scale = Math.min(scale, DecimalType.MAX_PRECISION);
232-
int precision = precision1 + precision2;
233-
precision = Math.min(precision, DecimalType.MAX_PRECISION);
234-
return new DecimalType(false, precision, scale);
236+
int precision = precision1 + precision2 + 1;
237+
return adjustPrecisionScale(precision, scale);
235238
}
236239

237240
/** Finds the result type of a decimal addition operation. */
238241
public static DecimalType findAdditionDecimalType(
239242
int precision1, int scale1, int precision2, int scale2) {
240-
// adopted from Calcite
241243
final int scale = Math.max(scale1, scale2);
242244
int precision = Math.max(precision1 - scale1, precision2 - scale2) + scale + 1;
243-
precision = Math.min(precision, DecimalType.MAX_PRECISION);
244-
return new DecimalType(false, precision, scale);
245+
return adjustPrecisionScale(precision, scale);
245246
}
246247

247248
/** Finds the result type of a decimal rounding operation. */
@@ -296,6 +297,27 @@ public static LogicalType findSumAggType(LogicalType argType) {
296297

297298
// --------------------------------------------------------------------------------------------
298299

300+
/**
301+
* The scale adjustment implementation is inspired by SQL Server's. In particular, when a result
302+
* precision is greater than MAX_PRECISION, the corresponding scale is reduced to prevent the
303+
* integral part of a result from being truncated.
304+
*
305+
* <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql
306+
*/
307+
private static DecimalType adjustPrecisionScale(int precision, int scale) {
308+
if (precision <= DecimalType.MAX_PRECISION) {
309+
// Adjustment only needed when we exceed max precision
310+
return new DecimalType(false, precision, scale);
311+
} else {
312+
int digitPart = precision - scale;
313+
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value;
314+
// otherwise preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
315+
int minScalePart = Math.min(scale, MINIMUM_ADJUSTED_SCALE);
316+
int adjustScale = Math.max(DecimalType.MAX_PRECISION - digitPart, minScalePart);
317+
return new DecimalType(false, DecimalType.MAX_PRECISION, adjustScale);
318+
}
319+
}
320+
299321
private static @Nullable LogicalType findCommonCastableType(List<LogicalType> normalizedTypes) {
300322
LogicalType resultType = normalizedTypes.get(0);
301323

flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/TypeStrategiesTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public static List<TestSpec> testData() {
197197
.expectDataType(DataTypes.DECIMAL(11, 8).notNull()),
198198
TestSpec.forStrategy("Find a decimal product", TypeStrategies.DECIMAL_TIMES)
199199
.inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2))
200-
.expectDataType(DataTypes.DECIMAL(8, 6).notNull()),
200+
.expectDataType(DataTypes.DECIMAL(9, 6).notNull()),
201201
TestSpec.forStrategy("Find a decimal modulo", TypeStrategies.DECIMAL_MOD)
202202
.inputTypes(DataTypes.DECIMAL(5, 4), DataTypes.DECIMAL(3, 2))
203203
.expectDataType(DataTypes.DECIMAL(5, 4).notNull()),

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2324
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2425
import org.apache.flink.table.types.DataType;
2526
import org.apache.flink.table.types.logical.DecimalType;
@@ -72,26 +73,31 @@ public Expression[] initialValuesExpressions() {
7273
@Override
7374
public Expression[] accumulateExpressions() {
7475
return new Expression[] {
75-
/* sum = */ ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0))),
76+
/* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), sum, plus(sum, operand(0)))),
7677
/* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L))),
7778
};
7879
}
7980

8081
@Override
8182
public Expression[] retractExpressions() {
8283
return new Expression[] {
83-
/* sum = */ ifThenElse(isNull(operand(0)), sum, minus(sum, operand(0))),
84+
/* sum = */ adjustSumType(ifThenElse(isNull(operand(0)), sum, minus(sum, operand(0)))),
8485
/* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L))),
8586
};
8687
}
8788

8889
@Override
8990
public Expression[] mergeExpressions() {
9091
return new Expression[] {
91-
/* sum = */ plus(sum, mergeOperand(sum)), /* count = */ plus(count, mergeOperand(count))
92+
/* sum = */ adjustSumType(plus(sum, mergeOperand(sum))),
93+
/* count = */ plus(count, mergeOperand(count))
9294
};
9395
}
9496

97+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
98+
return cast(sumExpr, typeLiteral(getSumType()));
99+
}
100+
95101
/** If all inputs are null, count will be 0 and we will get null after the division. */
96102
@Override
97103
public Expression getValueExpression() {

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2324
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2425
import org.apache.flink.table.types.DataType;
2526
import org.apache.flink.table.types.logical.DecimalType;
@@ -28,11 +29,13 @@
2829
import java.math.BigDecimal;
2930

3031
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
32+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
3133
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
3234
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
3335
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
3436
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
3537
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
38+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
3639

3740
/** built-in sum0 aggregate function. */
3841
public abstract class Sum0AggFunction extends DeclarativeAggregateFunction {
@@ -56,20 +59,25 @@ public DataType[] getAggBufferTypes() {
5659
@Override
5760
public Expression[] accumulateExpressions() {
5861
return new Expression[] {
59-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0)))
62+
/* sum0 = */ adjustSumType(ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0))))
6063
};
6164
}
6265

6366
@Override
6467
public Expression[] retractExpressions() {
6568
return new Expression[] {
66-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0)))
69+
/* sum0 = */ adjustSumType(
70+
ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0))))
6771
};
6872
}
6973

7074
@Override
7175
public Expression[] mergeExpressions() {
72-
return new Expression[] {/* sum0 = */ plus(sum0, mergeOperand(sum0))};
76+
return new Expression[] {/* sum0 = */ adjustSumType(plus(sum0, mergeOperand(sum0)))};
77+
}
78+
79+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
80+
return cast(sumExpr, typeLiteral(getResultType()));
7381
}
7482

7583
@Override

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,19 @@
2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.api.TableException;
2323
import org.apache.flink.table.expressions.Expression;
24+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2425
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2526
import org.apache.flink.table.types.DataType;
2627
import org.apache.flink.table.types.logical.DecimalType;
2728
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2829

2930
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
31+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
3032
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
3133
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
3234
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
3335
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
36+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
3437

3538
/** built-in sum aggregate function. */
3639
public abstract class SumAggFunction extends DeclarativeAggregateFunction {
@@ -59,10 +62,11 @@ public Expression[] initialValuesExpressions() {
5962
@Override
6063
public Expression[] accumulateExpressions() {
6164
return new Expression[] {
62-
/* sum = */ ifThenElse(
63-
isNull(operand(0)),
64-
sum,
65-
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))
65+
/* sum = */ adjustSumType(
66+
ifThenElse(
67+
isNull(operand(0)),
68+
sum,
69+
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))))
6670
};
6771
}
6872

@@ -75,13 +79,19 @@ public Expression[] retractExpressions() {
7579
@Override
7680
public Expression[] mergeExpressions() {
7781
return new Expression[] {
78-
/* sum = */ ifThenElse(
79-
isNull(mergeOperand(sum)),
80-
sum,
81-
ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))
82+
/* sum = */ adjustSumType(
83+
ifThenElse(
84+
isNull(mergeOperand(sum)),
85+
sum,
86+
ifThenElse(
87+
isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))))
8288
};
8389
}
8490

91+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
92+
return cast(sumExpr, typeLiteral(getResultType()));
93+
}
94+
8595
@Override
8696
public Expression getValueExpression() {
8797
return sum;

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,22 @@
2020

2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2324
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2425
import org.apache.flink.table.types.DataType;
2526
import org.apache.flink.table.types.logical.DecimalType;
2627
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2728

2829
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
30+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
2931
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo;
3032
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
3133
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
3234
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
3335
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
3436
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
3537
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
38+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
3639

3740
/** built-in sum aggregate function with retraction. */
3841
public abstract class SumWithRetractAggFunction extends DeclarativeAggregateFunction {
@@ -62,37 +65,47 @@ public Expression[] initialValuesExpressions() {
6265
@Override
6366
public Expression[] accumulateExpressions() {
6467
return new Expression[] {
65-
/* sum = */ ifThenElse(
66-
isNull(operand(0)),
67-
sum,
68-
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0)))),
68+
/* sum = */ adjustSumType(
69+
ifThenElse(
70+
isNull(operand(0)),
71+
sum,
72+
ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))),
6973
/* count = */ ifThenElse(isNull(operand(0)), count, plus(count, literal(1L)))
7074
};
7175
}
7276

7377
@Override
7478
public Expression[] retractExpressions() {
7579
return new Expression[] {
76-
/* sum = */ ifThenElse(
77-
isNull(operand(0)),
78-
sum,
80+
/* sum = */ adjustSumType(
7981
ifThenElse(
80-
isNull(sum), minus(zeroLiteral(), operand(0)), minus(sum, operand(0)))),
82+
isNull(operand(0)),
83+
sum,
84+
ifThenElse(
85+
isNull(sum),
86+
minus(zeroLiteral(), operand(0)),
87+
minus(sum, operand(0))))),
8188
/* count = */ ifThenElse(isNull(operand(0)), count, minus(count, literal(1L)))
8289
};
8390
}
8491

8592
@Override
8693
public Expression[] mergeExpressions() {
8794
return new Expression[] {
88-
/* sum = */ ifThenElse(
89-
isNull(mergeOperand(sum)),
90-
sum,
91-
ifThenElse(isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum)))),
95+
/* sum = */ adjustSumType(
96+
ifThenElse(
97+
isNull(mergeOperand(sum)),
98+
sum,
99+
ifThenElse(
100+
isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))),
92101
/* count = */ plus(count, mergeOperand(count))
93102
};
94103
}
95104

105+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
106+
return cast(sumExpr, typeLiteral(getResultType()));
107+
}
108+
96109
@Override
97110
public Expression getValueExpression() {
98111
return ifThenElse(equalTo(count, literal(0L)), nullOf(getResultType()), sum);

0 commit comments

Comments
 (0)