Skip to content

Commit dbe3906

Browse files
committed
For comments.
1 parent 13a2fc5 commit dbe3906

File tree

5 files changed

+54
-23
lines changed

5 files changed

+54
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.Interval
2626

2727
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
2828

29-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
29+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
3030

3131
override def dataType: DataType = child.dataType
3232

@@ -37,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
3737
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
3838
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
3939
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
40+
case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
4041
}
4142

42-
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
43+
protected override def nullSafeEval(input: Any): Any = {
44+
if (dataType.isInstanceOf[IntervalType]) {
45+
input.asInstanceOf[Interval].negate()
46+
} else {
47+
numeric.negate(input)
48+
}
49+
}
4350
}
4451

4552
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
4653
override def prettyName: String = "positive"
4754

48-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
55+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
4956

5057
override def dataType: DataType = child.dataType
5158

@@ -85,8 +92,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
8592
case ByteType | ShortType =>
8693
defineCodeGen(ctx, ev,
8794
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
88-
case IntervalType =>
89-
defineCodeGen(ctx, ev, (eval1, eval2) => s"""$eval1.doOp($eval2, "$symbol")""")
9095
case _ =>
9196
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
9297
}
@@ -98,8 +103,7 @@ private[sql] object BinaryArithmetic {
98103

99104
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
100105

101-
override def inputType: AbstractDataType = NumericType
102-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntervalType)
106+
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
103107

104108
override def symbol: String = "+"
105109
override def decimalMethod: String = "$plus"
@@ -116,12 +120,23 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
116120
numeric.plus(input1, input2)
117121
}
118122
}
123+
124+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
125+
case dt: DecimalType =>
126+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
127+
case ByteType | ShortType =>
128+
defineCodeGen(ctx, ev,
129+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
130+
case IntervalType =>
131+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
132+
case _ =>
133+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
134+
}
119135
}
120136

121137
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
122138

123-
override def inputType: AbstractDataType = NumericType
124-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntervalType)
139+
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
125140

126141
override def symbol: String = "-"
127142
override def decimalMethod: String = "$minus"
@@ -138,6 +153,18 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
138153
numeric.minus(input1, input2)
139154
}
140155
}
156+
157+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
158+
case dt: DecimalType =>
159+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
160+
case ByteType | ShortType =>
161+
defineCodeGen(ctx, ev,
162+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
163+
case IntervalType =>
164+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
165+
case _ =>
166+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
167+
}
141168
}
142169

143170
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ private[sql] object TypeCollection {
114114
BooleanType,
115115
ByteType, ShortType, IntegerType, LongType)
116116

117+
/**
118+
* Types that include numeric types and interval type. They are only used in unary_minus,
119+
* unary_positive, add and subtract operations.
120+
*/
121+
val NumericAndInterval = TypeCollection(
122+
ByteType, ShortType, IntegerType, LongType,
123+
FloatType, DoubleType, DecimalType, IntervalType)
124+
117125
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
118126

119127
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
5353
}
5454

5555
test("check types for unary arithmetic") {
56-
assertError(UnaryMinus('stringField), "argument 1 is expected to be of type numeric")
56+
assertError(UnaryMinus('stringField), "argument 1 is expected to be of type (tinyint or " +
57+
"smallint or int or bigint or float or double or decimal or interval")
5758
assertError(Abs('stringField), "argument 1 is expected to be of type numeric")
5859
assertError(BitwiseNot('stringField), "argument 1 is expected to be of type (boolean " +
5960
"or tinyint or smallint or int or bigint)")

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,5 +1504,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
15041504

15051505
checkAnswer(df.select(df("i") - new Interval(2, 123)),
15061506
Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
1507+
1508+
// unary minus
1509+
checkAnswer(df.select(-df("i")),
1510+
Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
15071511
}
15081512
}

unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,6 @@ public Interval(int months, long microseconds) {
8686
this.microseconds = microseconds;
8787
}
8888

89-
public Interval doOp(Interval that, String op) {
90-
Interval opRet = null;
91-
switch (op) {
92-
case "+":
93-
opRet = add(that);
94-
break;
95-
case "-":
96-
opRet = subtract(that);
97-
break;
98-
}
99-
return opRet;
100-
}
101-
10289
public Interval add(Interval that) {
10390
int months = this.months + that.months;
10491
long microseconds = this.microseconds + that.microseconds;
@@ -111,6 +98,10 @@ public Interval subtract(Interval that) {
11198
return new Interval(months, microseconds);
11299
}
113100

101+
public Interval negate() {
102+
return new Interval(-this.months, -this.microseconds);
103+
}
104+
114105
@Override
115106
public boolean equals(Object other) {
116107
if (this == other) return true;

0 commit comments

Comments
 (0)