Skip to content

Commit d3e9d0e

Browse files
committed
Add add and subtract expressions for IntervalType.
1 parent 408b384 commit d3e9d0e

File tree

7 files changed

+101
-6
lines changed

7 files changed

+101
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2323
import org.apache.spark.sql.catalyst.util.TypeUtils
2424
import org.apache.spark.sql.types._
25+
import org.apache.spark.unsafe.types.Interval
2526

2627
abstract class UnaryArithmetic extends UnaryExpression {
2728
self: Product =>
@@ -87,13 +88,19 @@ abstract class BinaryArithmetic extends BinaryOperator {
8788
def decimalMethod: String =
8889
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
8990

91+
/** Name of the function for this expression on a [[Interval]] type. */
92+
def intervalMethod: String =
93+
sys.error("BinaryArithmetics must override either intervalMethod or genCode")
94+
9095
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
9196
case dt: DecimalType =>
9297
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
9398
// byte and short are casted into int when add, minus, times or divide
9499
case ByteType | ShortType =>
95100
defineCodeGen(ctx, ev,
96101
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
102+
case IntervalType =>
103+
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$intervalMethod($eval2)")
97104
case _ =>
98105
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
99106
}
@@ -106,31 +113,45 @@ private[sql] object BinaryArithmetic {
106113
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
107114
override def symbol: String = "+"
108115
override def decimalMethod: String = "$plus"
116+
override def intervalMethod: String = "add"
109117

110118
override lazy val resolved =
111119
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
112120

113121
protected def checkTypesInternal(t: DataType) =
114-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
122+
TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol)
115123

116124
private lazy val numeric = TypeUtils.getNumeric(dataType)
117125

118-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
126+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
127+
if (dataType.isInstanceOf[IntervalType]) {
128+
input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
129+
} else {
130+
numeric.plus(input1, input2)
131+
}
132+
}
119133
}
120134

121135
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
122136
override def symbol: String = "-"
123137
override def decimalMethod: String = "$minus"
138+
override def intervalMethod: String = "subtract"
124139

125140
override lazy val resolved =
126141
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
127142

128143
protected def checkTypesInternal(t: DataType) =
129-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
144+
TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol)
130145

131146
private lazy val numeric = TypeUtils.getNumeric(dataType)
132147

133-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
148+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
149+
if (dataType.isInstanceOf[IntervalType]) {
150+
input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
151+
} else {
152+
numeric.minus(input1, input2)
153+
}
154+
}
134155
}
135156

136157
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.codehaus.janino.ClassBodyEvaluator
2626
import org.apache.spark.Logging
2727
import org.apache.spark.sql.catalyst.expressions._
2828
import org.apache.spark.sql.types._
29-
import org.apache.spark.unsafe.types.UTF8String
29+
import org.apache.spark.unsafe.types._
3030

3131

3232
// These classes are here to avoid issues with serialization and integration with quasiquotes.
@@ -57,6 +57,7 @@ class CodeGenContext {
5757
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
5858

5959
val stringType: String = classOf[UTF8String].getName
60+
val intervalType: String = classOf[Interval].getName
6061
val decimalType: String = classOf[Decimal].getName
6162

6263
final val JAVA_BOOLEAN = "boolean"
@@ -127,6 +128,7 @@ class CodeGenContext {
127128
case dt: DecimalType => decimalType
128129
case BinaryType => "byte[]"
129130
case StringType => stringType
131+
case IntervalType => intervalType
130132
case _: StructType => "InternalRow"
131133
case _: ArrayType => s"scala.collection.Seq"
132134
case _: MapType => s"scala.collection.Map"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
2424
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2525
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2626
import org.apache.spark.sql.types._
27-
import org.apache.spark.unsafe.types.UTF8String
27+
import org.apache.spark.unsafe.types._
2828

2929
object Literal {
3030
def apply(v: Any): Literal = v match {
@@ -42,6 +42,7 @@ object Literal {
4242
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
4343
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
4444
case a: Array[Byte] => Literal(a, BinaryType)
45+
case i: Interval => Literal(i, IntervalType)
4546
case null => Literal(null, NullType)
4647
case _ =>
4748
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ object TypeUtils {
3232
}
3333
}
3434

35+
def checkForNumericAndIntervalExpr(t: DataType, caller: String): TypeCheckResult = {
36+
if (t.isInstanceOf[NumericType] || t.isInstanceOf[IntervalType] || t == NullType) {
37+
TypeCheckResult.TypeCheckSuccess
38+
} else {
39+
TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric or interval types, not $t")
40+
}
41+
}
42+
3543
def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = {
3644
if (t.isInstanceOf[IntegralType] || t == NullType) {
3745
TypeCheckResult.TypeCheckSuccess

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,4 +1492,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
14921492
// Currently we don't yet support nanosecond
14931493
checkIntervalParseError("select interval 23 nanosecond")
14941494
}
1495+
1496+
test("SPARK-8945: add and subtract expressions for interval type") {
1497+
import org.apache.spark.unsafe.types.Interval
1498+
1499+
val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
1500+
checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
1501+
1502+
checkAnswer(df.select(df("i") + new Interval(2, 123)),
1503+
Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
1504+
1505+
checkAnswer(df.select(df("i") - new Interval(2, 123)),
1506+
Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
1507+
}
14951508
}

unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,18 @@ public Interval(int months, long microseconds) {
8686
this.microseconds = microseconds;
8787
}
8888

89+
public Interval add(Interval that) {
90+
int months = this.months + that.months;
91+
long microseconds = this.microseconds + that.microseconds;
92+
return new Interval(months, microseconds);
93+
}
94+
95+
public Interval subtract(Interval that) {
96+
int months = this.months - that.months;
97+
long microseconds = this.microseconds - that.microseconds;
98+
return new Interval(months, microseconds);
99+
}
100+
89101
@Override
90102
public boolean equals(Object other) {
91103
if (this == other) return true;

unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,44 @@ public void fromStringTest() {
9595
assertEquals(Interval.fromString(input), null);
9696
}
9797

98+
@Test
99+
public void addTest() {
100+
String input = "interval 3 month 1 hour";
101+
String input2 = "interval 2 month 100 hour";
102+
103+
Interval interval = Interval.fromString(input);
104+
Interval interval2 = Interval.fromString(input2);
105+
106+
assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR));
107+
108+
input = "interval -10 month -81 hour";
109+
input2 = "interval 75 month 200 hour";
110+
111+
interval = Interval.fromString(input);
112+
interval2 = Interval.fromString(input2);
113+
114+
assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR));
115+
}
116+
117+
@Test
118+
public void subtractTest() {
119+
String input = "interval 3 month 1 hour";
120+
String input2 = "interval 2 month 100 hour";
121+
122+
Interval interval = Interval.fromString(input);
123+
Interval interval2 = Interval.fromString(input2);
124+
125+
assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR));
126+
127+
input = "interval -10 month -81 hour";
128+
input2 = "interval 75 month 200 hour";
129+
130+
interval = Interval.fromString(input);
131+
interval2 = Interval.fromString(input2);
132+
133+
assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR));
134+
}
135+
98136
private void testSingleUnit(String unit, int number, int months, long microseconds) {
99137
String input1 = "interval " + number + " " + unit;
100138
String input2 = "interval " + number + " " + unit + "s";

0 commit comments

Comments
 (0)