diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 0983810c9ad1a..c451eb2b877da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -828,8 +828,9 @@ object TypeCoercion { /** * 1. Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub. - * 2. Turns Add/Subtract of DateType/IntegerType and IntegerType/DateType - * to DateAdd/DateSub/DateDiff. + * 2. Turns Add/Subtract of DateType/IntegerType and IntegerType/DateType to DateAdd/DateSub, + * Subtract of DateType/DateType to SubtractDates (DateDiff in the PostgreSQL dialect), and + * Subtract of TimestampType with TimestampType/DateType (either order) to SubtractTimestamps. */ object DateTimeOperations extends Rule[LogicalPlan] { @@ -849,12 +850,14 @@ object TypeCoercion { case Add(l @ DateType(), r @ IntegerType()) => DateAdd(l, r) case Add(l @ IntegerType(), r @ DateType()) => DateAdd(r, l) case Subtract(l @ DateType(), r @ IntegerType()) => DateSub(l, r) - case Subtract(l @ DateType(), r @ DateType()) => DateDiff(l, r) - case Subtract(l @ TimestampType(), r @ TimestampType()) => TimestampDiff(l, r) + case Subtract(l @ DateType(), r @ DateType()) => + if (SQLConf.get.usePostgreSQLDialect) DateDiff(l, r) else SubtractDates(l, r) + case Subtract(l @ TimestampType(), r @ TimestampType()) => + SubtractTimestamps(l, r) case Subtract(l @ TimestampType(), r @ DateType()) => - TimestampDiff(l, Cast(r, TimestampType)) + SubtractTimestamps(l, Cast(r, TimestampType)) case Subtract(l @ DateType(), r @ TimestampType()) => - TimestampDiff(Cast(l, TimestampType), r) + SubtractTimestamps(Cast(l, TimestampType), r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 5aea884ad5003..cddd8c9bd61b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -2099,7 +2099,7 @@ case class DatePart(field: Expression, source: Expression, child: Expression) * is set to 0 and the `microseconds` field is initialized to the microsecond difference * between the given timestamps. */ -case class TimestampDiff(endTimestamp: Expression, startTimestamp: Expression) +case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = endTimestamp @@ -2116,3 +2116,25 @@ case class TimestampDiff(endTimestamp: Expression, startTimestamp: Expression) s"new org.apache.spark.unsafe.types.CalendarInterval(0, $end - $start)") } } + +/** + * Returns the interval from the `right` date (inclusive) to the `left` date (exclusive). 
+ */ +case class SubtractDates(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) + override def dataType: DataType = CalendarIntervalType + + override def nullSafeEval(leftDays: Any, rightDays: Any): Any = { + DateTimeUtils.subtractDates(leftDays.asInstanceOf[Int], rightDays.asInstanceOf[Int]) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + defineCodeGen(ctx, ev, (leftDays, rightDays) => { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + s"$dtu.subtractDates($leftDays, $rightDays)" + }) + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 34e8012106bbe..088876921dccd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -27,7 +27,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * Helper functions for converting between internal and external date and time representations. @@ -950,4 +950,20 @@ object DateTimeUtils { None } } + + /** + * Subtracts two dates. + * @param endDate - the end date, exclusive + * @param startDate - the start date, inclusive + * @return an interval between two dates. The interval can be negative + * if the end date is before the start date. 
+ */ + def subtractDates(endDate: SQLDate, startDate: SQLDate): CalendarInterval = { + val period = Period.between( + LocalDate.ofEpochDay(startDate), + LocalDate.ofEpochDay(endDate)) + val months = period.getMonths + 12 * period.getYears + val microseconds = period.getDays * MICROS_PER_DAY + new CalendarInterval(months, microseconds) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f60e0f2bfee6a..4f9e4ec0201dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1430,13 +1430,13 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(dateTimeOperations, Add(date, intValue), DateAdd(date, intValue)) ruleTest(dateTimeOperations, Add(intValue, date), DateAdd(date, intValue)) ruleTest(dateTimeOperations, Subtract(date, intValue), DateSub(date, intValue)) - ruleTest(dateTimeOperations, Subtract(date, date), DateDiff(date, date)) + ruleTest(dateTimeOperations, Subtract(date, date), SubtractDates(date, date)) ruleTest(dateTimeOperations, Subtract(timestamp, timestamp), - TimestampDiff(timestamp, timestamp)) + SubtractTimestamps(timestamp, timestamp)) ruleTest(dateTimeOperations, Subtract(timestamp, date), - TimestampDiff(timestamp, Cast(date, TimestampType))) + SubtractTimestamps(timestamp, Cast(date, TimestampType))) ruleTest(dateTimeOperations, Subtract(date, timestamp), - TimestampDiff(Cast(date, TimestampType), timestamp)) + SubtractTimestamps(Cast(date, TimestampType), timestamp)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 8680a15ee1cd7..e893e863b3675 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.time.{Instant, LocalDateTime, ZoneId, ZoneOffset} +import java.time.{Instant, LocalDate, LocalDateTime, ZoneId, ZoneOffset} import java.util.{Calendar, Locale, TimeZone} import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit._ @@ -1072,19 +1072,39 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("timestamps difference") { val end = Instant.parse("2019-10-04T11:04:01.123456Z") - checkEvaluation(TimestampDiff(Literal(end), Literal(end)), + checkEvaluation(SubtractTimestamps(Literal(end), Literal(end)), new CalendarInterval(0, 0)) - checkEvaluation(TimestampDiff(Literal(end), Literal(Instant.EPOCH)), + checkEvaluation(SubtractTimestamps(Literal(end), Literal(Instant.EPOCH)), CalendarInterval.fromString("interval 18173 days " + "11 hours 4 minutes 1 seconds 123 milliseconds 456 microseconds")) - checkEvaluation(TimestampDiff(Literal(Instant.EPOCH), Literal(end)), + checkEvaluation(SubtractTimestamps(Literal(Instant.EPOCH), Literal(end)), CalendarInterval.fromString("interval -18173 days " + "-11 hours -4 minutes -1 seconds -123 milliseconds -456 microseconds")) checkEvaluation( - TimestampDiff( + SubtractTimestamps( Literal(Instant.parse("9999-12-31T23:59:59.999999Z")), Literal(Instant.parse("0001-01-01T00:00:00Z"))), CalendarInterval.fromString("interval 521722 weeks 4 days " + "23 hours 59 minutes 59 seconds 999 milliseconds 999 microseconds")) } + + test("subtract dates") { + val end = LocalDate.of(2019, 10, 5) + checkEvaluation(SubtractDates(Literal(end), Literal(end)), + new CalendarInterval(0, 0)) + checkEvaluation(SubtractDates(Literal(end.plusDays(1)), 
Literal(end)), + CalendarInterval.fromString("interval 1 days")) + checkEvaluation(SubtractDates(Literal(end.minusDays(1)), Literal(end)), + CalendarInterval.fromString("interval -1 days")) + val epochDate = Literal(LocalDate.ofEpochDay(0)) + checkEvaluation(SubtractDates(Literal(end), epochDate), + CalendarInterval.fromString("interval 49 years 9 months 4 days")) + checkEvaluation(SubtractDates(epochDate, Literal(end)), + CalendarInterval.fromString("interval -49 years -9 months -4 days")) + checkEvaluation( + SubtractDates( + Literal(LocalDate.of(10000, 1, 1)), + Literal(LocalDate.of(1, 1, 1))), + CalendarInterval.fromString("interval 9999 years")) + } } diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index c3c131d22d0fb..0f4036cad6125 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -126,15 +126,15 @@ struct -- !query 14 select date '2001-10-01' - date '2001-09-28' -- !query 14 schema -struct +struct -- !query 14 output -3 +interval 3 days -- !query 15 select date'2020-01-01' - timestamp'2019-10-06 10:11:12.345678' -- !query 15 schema -struct +struct -- !query 15 output interval 12 weeks 2 days 14 hours 48 minutes 47 seconds 654 milliseconds 322 microseconds @@ -142,6 +142,6 @@ interval 12 weeks 2 days 14 hours 48 minutes 47 seconds 654 milliseconds 322 mic -- !query 16 select timestamp'2019-10-06 10:11:12.345678' - date'2020-01-01' -- !query 16 schema -struct +struct -- !query 16 output interval -12 weeks -2 days -14 hours -48 minutes -47 seconds -654 milliseconds -322 microseconds