Skip to content

Commit eba6a1a

Browse files
viiryarxin
authored andcommitted
[SPARK-8945][SQL] Add add and subtract expressions for IntervalType
JIRA: https://issues.apache.org/jira/browse/SPARK-8945 Add add and subtract expressions for IntervalType. Author: Liang-Chi Hsieh <[email protected]> This patch had conflicts when merged, resolved by Committer: Reynold Xin <[email protected]> Closes apache#7398 from viirya/interval_add_subtract and squashes the following commits: acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 5abae28 [Liang-Chi Hsieh] For comments. 6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract dbe3906 [Liang-Chi Hsieh] For comments. 13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract 83ec129 [Liang-Chi Hsieh] Remove intervalMethod. acfe1ab [Liang-Chi Hsieh] Fix scala style. d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType.
1 parent 305e77c commit eba6a1a

File tree

8 files changed

+136
-14
lines changed

8 files changed

+136
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2222
import org.apache.spark.sql.catalyst.util.TypeUtils
2323
import org.apache.spark.sql.types._
24+
import org.apache.spark.unsafe.types.Interval
2425

2526

2627
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
2728

28-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
29+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
2930

3031
override def dataType: DataType = child.dataType
3132

@@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
3637
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
3738
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
3839
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
40+
case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
3941
}
4042

41-
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
43+
protected override def nullSafeEval(input: Any): Any = {
44+
if (dataType.isInstanceOf[IntervalType]) {
45+
input.asInstanceOf[Interval].negate()
46+
} else {
47+
numeric.negate(input)
48+
}
49+
}
4250
}
4351

4452
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
4553
override def prettyName: String = "positive"
4654

47-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
55+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
4856

4957
override def dataType: DataType = child.dataType
5058

@@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic {
95103

96104
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
97105

98-
override def inputType: AbstractDataType = NumericType
106+
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
99107

100108
override def symbol: String = "+"
101-
override def decimalMethod: String = "$plus"
102109

103110
override lazy val resolved =
104111
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
105112

106113
private lazy val numeric = TypeUtils.getNumeric(dataType)
107114

108-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
115+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
116+
if (dataType.isInstanceOf[IntervalType]) {
117+
input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
118+
} else {
119+
numeric.plus(input1, input2)
120+
}
121+
}
122+
123+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
124+
case dt: DecimalType =>
125+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
126+
case ByteType | ShortType =>
127+
defineCodeGen(ctx, ev,
128+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
129+
case IntervalType =>
130+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
131+
case _ =>
132+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
133+
}
109134
}
110135

111136
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
112137

113-
override def inputType: AbstractDataType = NumericType
138+
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
114139

115140
override def symbol: String = "-"
116-
override def decimalMethod: String = "$minus"
117141

118142
override lazy val resolved =
119143
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
120144

121145
private lazy val numeric = TypeUtils.getNumeric(dataType)
122146

123-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
147+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
148+
if (dataType.isInstanceOf[IntervalType]) {
149+
input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
150+
} else {
151+
numeric.minus(input1, input2)
152+
}
153+
}
154+
155+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
156+
case dt: DecimalType =>
157+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
158+
case ByteType | ShortType =>
159+
defineCodeGen(ctx, ev,
160+
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
161+
case IntervalType =>
162+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
163+
case _ =>
164+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
165+
}
124166
}
125167

126168
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
2727
import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.catalyst.expressions._
2929
import org.apache.spark.sql.types._
30-
import org.apache.spark.unsafe.types.UTF8String
30+
import org.apache.spark.unsafe.types._
3131

3232

3333
// These classes are here to avoid issues with serialization and integration with quasiquotes.
@@ -69,6 +69,7 @@ class CodeGenContext {
6969
mutableStates += ((javaType, variableName, initialValue))
7070
}
7171

72+
final val intervalType: String = classOf[Interval].getName
7273
final val JAVA_BOOLEAN = "boolean"
7374
final val JAVA_BYTE = "byte"
7475
final val JAVA_SHORT = "short"
@@ -137,6 +138,7 @@ class CodeGenContext {
137138
case dt: DecimalType => "Decimal"
138139
case BinaryType => "byte[]"
139140
case StringType => "UTF8String"
141+
case IntervalType => intervalType
140142
case _: StructType => "InternalRow"
141143
case _: ArrayType => s"scala.collection.Seq"
142144
case _: MapType => s"scala.collection.Map"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
2424
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2525
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2626
import org.apache.spark.sql.types._
27-
import org.apache.spark.unsafe.types.UTF8String
27+
import org.apache.spark.unsafe.types._
2828

2929
object Literal {
3030
def apply(v: Any): Literal = v match {
@@ -42,6 +42,7 @@ object Literal {
4242
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
4343
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
4444
case a: Array[Byte] => Literal(a, BinaryType)
45+
case i: Interval => Literal(i, IntervalType)
4546
case null => Literal(null, NullType)
4647
case _ =>
4748
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ private[sql] object TypeCollection {
9191
TimestampType, DateType,
9292
StringType, BinaryType)
9393

94+
/**
95+
* Types that include numeric types and interval type. They are only used in unary_minus,
96+
* unary_positive, add and subtract operations.
97+
*/
98+
val NumericAndInterval = TypeCollection(NumericType, IntervalType)
99+
94100
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
95101

96102
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
5353
}
5454

5555
test("check types for unary arithmetic") {
56-
assertError(UnaryMinus('stringField), "expected to be of type numeric")
56+
assertError(UnaryMinus('stringField), "type (numeric or interval)")
5757
assertError(Abs('stringField), "expected to be of type numeric")
5858
assertError(BitwiseNot('stringField), "expected to be of type integral")
5959
}
@@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
7878
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
7979
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
8080

81-
assertError(Add('booleanField, 'booleanField), "accepts numeric type")
82-
assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
81+
assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
82+
assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
8383
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
8484
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
8585
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
14921492
// Currently we don't yet support nanosecond
14931493
checkIntervalParseError("select interval 23 nanosecond")
14941494
}
1495+
1496+
test("SPARK-8945: add and subtract expressions for interval type") {
1497+
import org.apache.spark.unsafe.types.Interval
1498+
1499+
val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
1500+
checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
1501+
1502+
checkAnswer(df.select(df("i") + new Interval(2, 123)),
1503+
Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
1504+
1505+
checkAnswer(df.select(df("i") - new Interval(2, 123)),
1506+
Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
1507+
1508+
// unary minus
1509+
checkAnswer(df.select(-df("i")),
1510+
Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
1511+
}
14951512
}

unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,22 @@ public Interval(int months, long microseconds) {
8787
this.microseconds = microseconds;
8888
}
8989

90+
public Interval add(Interval that) {
91+
int months = this.months + that.months;
92+
long microseconds = this.microseconds + that.microseconds;
93+
return new Interval(months, microseconds);
94+
}
95+
96+
public Interval subtract(Interval that) {
97+
int months = this.months - that.months;
98+
long microseconds = this.microseconds - that.microseconds;
99+
return new Interval(months, microseconds);
100+
}
101+
102+
public Interval negate() {
103+
return new Interval(-this.months, -this.microseconds);
104+
}
105+
90106
@Override
91107
public boolean equals(Object other) {
92108
if (this == other) return true;

unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,44 @@ public void fromStringTest() {
101101
assertEquals(Interval.fromString(input), null);
102102
}
103103

104+
@Test
105+
public void addTest() {
106+
String input = "interval 3 month 1 hour";
107+
String input2 = "interval 2 month 100 hour";
108+
109+
Interval interval = Interval.fromString(input);
110+
Interval interval2 = Interval.fromString(input2);
111+
112+
assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR));
113+
114+
input = "interval -10 month -81 hour";
115+
input2 = "interval 75 month 200 hour";
116+
117+
interval = Interval.fromString(input);
118+
interval2 = Interval.fromString(input2);
119+
120+
assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR));
121+
}
122+
123+
@Test
124+
public void subtractTest() {
125+
String input = "interval 3 month 1 hour";
126+
String input2 = "interval 2 month 100 hour";
127+
128+
Interval interval = Interval.fromString(input);
129+
Interval interval2 = Interval.fromString(input2);
130+
131+
assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR));
132+
133+
input = "interval -10 month -81 hour";
134+
input2 = "interval 75 month 200 hour";
135+
136+
interval = Interval.fromString(input);
137+
interval2 = Interval.fromString(input2);
138+
139+
assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR));
140+
}
141+
104142
private void testSingleUnit(String unit, int number, int months, long microseconds) {
105143
String input1 = "interval " + number + " " + unit;
106144
String input2 = "interval " + number + " " + unit + "s";

0 commit comments

Comments
 (0)