From d3e9d0e7a834075736273a791a8f55f7408c6cbe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 01:12:11 +0800 Subject: [PATCH 1/5] Add add and subtract expressions for IntervalType. --- .../sql/catalyst/expressions/arithmetic.scala | 29 ++++++++++++-- .../expressions/codegen/CodeGenerator.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 3 +- .../spark/sql/catalyst/util/TypeUtils.scala | 8 ++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +++++++ .../apache/spark/unsafe/types/Interval.java | 12 ++++++ .../spark/unsafe/types/IntervalSuite.java | 38 +++++++++++++++++++ 7 files changed, 101 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8476af4a5d8d..7ce654b9e762 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.Interval abstract class UnaryArithmetic extends UnaryExpression { self: Product => @@ -87,6 +88,10 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") + /** Name of the function for this expression on a [[Interval]] type. */ + def intervalMethod: String = + sys.error("BinaryArithmetics must override either intervalMethod or genCode") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") @@ -94,6 +99,8 @@ abstract class BinaryArithmetic extends BinaryOperator { case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$intervalMethod($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -106,31 +113,45 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" override def decimalMethod: String = "$plus" + override def intervalMethod: String = "add" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) + TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + } else { + numeric.plus(input1, input2) + } + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" override def decimalMethod: String = "$minus" + override def intervalMethod: String = "subtract" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) + TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol) private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + } else { + numeric.minus(input1, input2) + } + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4e..36ac7a0fadca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -26,7 +26,7 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ // These classes are here to avoid issues with serialization and integration with quasiquotes. @@ -57,6 +57,7 @@ class CodeGenContext { val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() val stringType: String = classOf[UTF8String].getName + val intervalType: String = classOf[Interval].getName val decimalType: String = classOf[Decimal].getName final val JAVA_BOOLEAN = "boolean" @@ -127,6 +128,7 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType + case IntervalType => intervalType case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a7a7ae44003..e1fdb29541fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types._ object Literal { def apply(v: Any): Literal = v match { @@ -42,6 +42,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) + case i: Interval => Literal(i, IntervalType) case null => Literal(null, NullType) case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166..c68eee7f9d3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -32,6 +32,14 @@ object TypeUtils { } } + def checkForNumericAndIntervalExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t.isInstanceOf[IntervalType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric or interval types, not $t") + } + } + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[IntegralType] || t == NullType) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 231440892bf0..7edbc6051844 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1492,4 +1492,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Currently we don't yet support nanosecond checkIntervalParseError("select interval 23 nanosecond") } + + test("SPARK-8945: add and subtract expressions for interval type") { + import org.apache.spark.unsafe.types.Interval + + val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + + checkAnswer(df.select(df("i") + new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + + checkAnswer(df.select(df("i") - new Interval(2, 123)), + Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index eb7475e9df86..ec3f0e263e58 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -86,6 +86,18 @@ public Interval(int months, long microseconds) { this.microseconds = microseconds; } + public Interval add(Interval that) { + int months = this.months + that.months; + long microseconds = this.microseconds + that.microseconds; + return new Interval(months, microseconds); + } + + public Interval subtract(Interval that) { + int months = this.months - that.months; + long microseconds = this.microseconds - that.microseconds; + return new Interval(months, microseconds); + } + @Override public boolean equals(Object other) { if (this == other) return true; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 44a949a371f2..4e96f325159a 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -95,6 +95,44 @@ public void fromStringTest() { assertEquals(Interval.fromString(input), null); } + @Test + public void addTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + } + + @Test + public void subtractTest() { + String input = "interval 3 month 1 hour"; + String input2 = "interval 2 month 100 hour"; + + Interval interval = Interval.fromString(input); + Interval interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR)); + + input = "interval -10 month -81 hour"; + input2 = "interval 75 month 200 hour"; + + interval = Interval.fromString(input); + interval2 = Interval.fromString(input2); + + assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR)); + } + private void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; From acfe1ab3c349faf8cb279d1d8ff708f7e0ef50b3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 10:28:23 +0800 Subject: [PATCH 2/5] Fix scala style. --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 36ac7a0fadca..1502f304fdfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -128,7 +128,7 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case IntervalType => intervalType + case IntervalType => intervalType case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" From 83ec1292f00610eb62649ac6eb0879e6db85c1b2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 13:17:17 +0800 Subject: [PATCH 3/5] Remove intervalMethod. --- .../spark/sql/catalyst/expressions/arithmetic.scala | 8 +------- .../analysis/ExpressionTypeCheckingSuite.scala | 4 ++-- .../org/apache/spark/unsafe/types/Interval.java | 13 +++++++++++++ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7ce654b9e762..afa3a416fbe4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -88,10 +88,6 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - /** Name of the function for this expression on a [[Interval]] type. */ - def intervalMethod: String = - sys.error("BinaryArithmetics must override either intervalMethod or genCode") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") @@ -100,7 +96,7 @@ abstract class BinaryArithmetic extends BinaryOperator { defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") case IntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$intervalMethod($eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => s"""$eval1.doOp($eval2, "$symbol")""") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -113,7 +109,6 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" override def decimalMethod: String = "$plus" - override def intervalMethod: String = "add" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -135,7 +130,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" override def decimalMethod: String = "$minus" - override def intervalMethod: String = "subtract" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea..81ae6ea317da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric or interval types") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric or interval types") assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index ec3f0e263e58..abee7f9ae62a 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -86,6 +86,19 @@ public Interval(int months, long microseconds) { this.microseconds = microseconds; } + public Interval doOp(Interval that, String op) { + Interval opRet = null; + switch (op) { + case "+": + opRet = add(that); + break; + case "-": + opRet = subtract(that); + break; + } + return opRet; + } + public Interval add(Interval that) { int months = this.months + that.months; long microseconds = this.microseconds + that.microseconds; From dbe39061719e2d8ccf0006fffb487a780232f41c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 15:43:25 +0800 Subject: [PATCH 4/5] For comments. --- .../sql/catalyst/expressions/arithmetic.scala | 45 +++++++++++++++---- .../spark/sql/types/AbstractDataType.scala | 8 ++++ .../ExpressionTypeCheckingSuite.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 ++ .../apache/spark/unsafe/types/Interval.java | 17 ++----- 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index dfa2084b2f4c..9cb77d394a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.Interval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -37,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") + case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } - protected override def nullSafeEval(input: Any): Any = numeric.negate(input) + protected override def nullSafeEval(input: Any): Any = { + if (dataType.isInstanceOf[IntervalType]) { + input.asInstanceOf[Interval].negate() + } else { + numeric.negate(input) + } + } } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) override def dataType: DataType = child.dataType @@ -85,8 +92,6 @@ abstract class BinaryArithmetic extends BinaryOperator { case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"""$eval1.doOp($eval2, "$symbol")""") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -98,8 +103,7 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntervalType) + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" override def decimalMethod: String = "$plus" @@ -116,12 +120,23 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { numeric.plus(input1, input2) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { - override def inputType: AbstractDataType = NumericType - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntervalType) + override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" override def decimalMethod: String = "$minus" @@ -138,6 +153,18 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti numeric.minus(input1, input2) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + case ByteType | ShortType => + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case IntervalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index f5715f7a829f..b14ea533f5db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -114,6 +114,14 @@ private[sql] object TypeCollection { BooleanType, ByteType, ShortType, IntegerType, LongType) + /** + * Types that include numeric types and interval type. They are only used in unary_minus, + * unary_positive, add and subtract operations. + */ + val NumericAndInterval = TypeCollection( + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, IntervalType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ebc56fbcfa9a..a51e42e5b675 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "argument 1 is expected to be of type numeric") + assertError(UnaryMinus('stringField), "argument 1 is expected to be of type (tinyint or " + + "smallint or int or bigint or float or double or decimal or interval") assertError(Abs('stringField), "argument 1 is expected to be of type numeric") assertError(BitwiseNot('stringField), "argument 1 is expected to be of type (boolean " + "or tinyint or smallint or int or bigint)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7edbc6051844..5b8b70ed5ae1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1504,5 +1504,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(df.select(df("i") - new Interval(2, 123)), Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + + // unary minus + checkAnswer(df.select(-df("i")), + Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index abee7f9ae62a..7d9f9e9c8cae 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -86,19 +86,6 @@ public Interval(int months, long microseconds) { this.microseconds = microseconds; } - public Interval doOp(Interval that, String op) { - Interval opRet = null; - switch (op) { - case "+": - opRet = add(that); - break; - case "-": - opRet = subtract(that); - break; - } - return opRet; - } - public Interval add(Interval that) { int months = this.months + that.months; long microseconds = this.microseconds + that.microseconds; @@ -111,6 +98,10 @@ public Interval subtract(Interval that) { return new Interval(months, microseconds); } + public Interval negate() { + return new Interval(-this.months, -this.microseconds); + } + @Override public boolean equals(Object other) { if (this == other) return true; From 5abae28c48cb2c4fa908621121d19416d51ace9b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Jul 2015 16:08:26 +0800 Subject: [PATCH 5/5] For comments. --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 6 ++---- .../scala/org/apache/spark/sql/types/AbstractDataType.scala | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9cb77d394a45..0ee2f5e57502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -106,7 +106,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "+" - override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -123,7 +122,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") @@ -139,7 +138,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def inputType: AbstractDataType = TypeCollection.NumericAndInterval override def symbol: String = "-" - override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -156,7 +154,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => - defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index b14ea533f5db..8ac8ac5b0bb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -120,7 +120,8 @@ private[sql] object TypeCollection { */ val NumericAndInterval = TypeCollection( ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType, IntervalType) + FloatType, DoubleType, DecimalType, + IntervalType) def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)