@@ -246,7 +246,7 @@ class Analyzer(
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
ResolveBinaryArithmetic(conf) ::
ResolveBinaryArithmetic ::
TypeCoercion.typeCoercionRules(conf) ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
@@ -268,17 +268,21 @@ class Analyzer(
/**
* For [[Add]]:
* 1. if both sides are intervals, stays the same;
* 2. else if one side is interval, turns it to [[TimeAdd]];
* 3. else if one side is date, turns it to [[DateAdd]];
* 4. else stays the same.
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 4. else if one side is date, turns it to [[DateAdd]];
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both sides are intervals, stays the same;
* 2. else if the right side is an interval, turns it to [[TimeSub]];
* 3. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 4. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 5. else if the left side is date, turns it to [[DateSub]];
* 6. else stays the same.
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeSub]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
* 7. else stays the same.
*
* For [[Multiply]]:
* 1. If one side is interval, turns it to [[MultiplyInterval]];
@@ -288,19 +292,22 @@
* 1. If the left side is interval, turns it to [[DivideInterval]];
* 2. otherwise, stays the same.
*/
case class ResolveBinaryArithmetic(conf: SQLConf) extends Rule[LogicalPlan] {
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan => p.transformExpressionsUp {
case a @ Add(l, r) if a.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => a
case (DateType, CalendarIntervalType) => DateAddInterval(l, r)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l)
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r) if s.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => s
case (DateType, CalendarIntervalType) => DateAddInterval(l, UnaryMinus(r))
Contributor: This is a good idea, maybe we can remove TimeSub later.

Member (author): yea, we can clean it up later
case (_, CalendarIntervalType) => Cast(TimeSub(l, r), l.dataType)
case (TimestampType, _) => SubtractTimestamps(l, r)
case (_, TimestampType) => SubtractTimestamps(l, r)
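As a quick way to see the new resolution in action, a spark-shell sketch along these lines could be used (assuming a build that includes this patch; the session setup and literals are illustrative, not part of the change):

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: inspect how DATE +/- INTERVAL is analyzed after this change.
val spark = SparkSession.builder().master("local[1]").appName("resolve-binary-arithmetic").getOrCreate()

// With the new rule, the analyzed plan should show DateAddInterval for the addition
// (instead of Cast(TimeAdd(...), DateType)) and DateAddInterval(date, -interval) for the subtraction.
spark.sql("SELECT DATE'2020-01-01' + INTERVAL '1 month' AS plus").explain(extended = true)
spark.sql("SELECT DATE'2020-01-01' - INTERVAL '1 month' AS minus").explain(extended = true)

spark.stop()
```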
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, Tim
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

@@ -1157,6 +1158,68 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
}
}

/**
* Adds a date and an interval.
*
* When ANSI mode is on, the microseconds part of the interval needs to be 0, otherwise a runtime
* [[IllegalArgumentException]] will be raised.
* When ANSI mode is off, if the microseconds part of the interval is 0, we perform date + interval
* for better performance; if the microseconds part is not 0, the date is converted to a
* timestamp and the whole interval is added to that timestamp.
*/
case class DateAddInterval(
start: Expression,
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left + $right"
override def sql: String = s"${left.sql} + ${right.sql}"
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, CalendarIntervalType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
if (ansiEnabled || itvl.microseconds == 0) {
DateTimeUtils.dateAddInterval(start.asInstanceOf[Int], itvl)
} else {
val startTs = DateTimeUtils.epochDaysToMicros(start.asInstanceOf[Int], zoneId)
val resultTs = DateTimeUtils.timestampAddInterval(
startTs, itvl.months, itvl.days, itvl.microseconds, zoneId)
DateTimeUtils.microsToEpochDays(resultTs, zoneId)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (sd, i) => if (ansiEnabled) {
s"""${ev.value} = $dtu.dateAddInterval($sd, $i);"""
} else {
val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val startTs = ctx.freshName("startTs")
val resultTs = ctx.freshName("resultTs")
s"""
|if ($i.microseconds == 0) {
| ${ev.value} = $dtu.dateAddInterval($sd, $i);
|} else {
| long $startTs = $dtu.epochDaysToMicros($sd, $zid);
| long $resultTs =
| $dtu.timestampAddInterval($startTs, $i.months, $i.days, $i.microseconds, $zid);
| ${ev.value} = $dtu.microsToEpochDays($resultTs, $zid);
|}
|""".stripMargin
})
}

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
}
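To make the ANSI/non-ANSI split concrete, the following spark-shell sketch illustrates both paths (the config key and the expected value mirror the DateExpressionsSuite test added below; everything else is illustrative):

```scala
// Sketch: DATE + INTERVAL when the interval carries a time component.
spark.conf.set("spark.sql.ansi.enabled", "false")
// Non-ANSI: the date is routed through a timestamp, so the 25-hour part rolls the result
// forward one more day; expected 2016-03-30 (see the unit test below).
spark.sql("SELECT DATE'2016-02-28' + INTERVAL '1 month 1 day 25 hours'").show()

spark.conf.set("spark.sql.ansi.enabled", "true")
// ANSI: adding hours/minutes/seconds/millis/micros to a date is rejected at runtime with
// "Cannot add hours, minutes or seconds, milliseconds, microseconds to a date".
// spark.sql("SELECT DATE'2016-02-28' + INTERVAL '1 month 1 day 25 hours'").show()
```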

/**
* This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
* takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and
@@ -618,6 +618,22 @@ object DateTimeUtils {
instantToMicros(resultTimestamp.toInstant)
}

/**
* Add the date and the interval's months and days.
* Returns a date value, expressed in days since 1.1.1970.
*
* @throws DateTimeException if the result exceeds the supported date range
* @throws IllegalArgumentException if the interval has a `microseconds` part
*/
def dateAddInterval(
start: SQLDate,
interval: CalendarInterval): SQLDate = {
require(interval.microseconds == 0,
"Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
val ld = LocalDate.ofEpochDay(start).plusMonths(interval.months).plusDays(interval.days)
Member: Actually, the result depends on the order of plusMonths() and plusDays(). @yaooqinn Did you make the choice intentionally? I am asking you because adding days and months can be much cheaper.

Member (author): Yes, here we are following the previous behavior of using timestamp + interval.

Member: I see, thanks. It would be nice to document such behavior of this function and timestampAddInterval somewhere. It is not obvious that we add months, then days, and then micros. The order could be opposite.

Contributor: FYI, in Snowflake, interval '1 month 1 day' is different from interval '1 day 1 month'. We should at least document our own behavior.

Member (author, @yaooqinn, May 6, 2020): I will make time for that PR.
localDateToDays(ld)
}
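The ordering point raised in the thread above can be seen with plain java.time; the values below match the "date add interval with day precision" test added in DateTimeUtilsSuite (the months-first order is what dateAddInterval implements; the days-first order is shown only for contrast):

```scala
import java.time.LocalDate

// 1997-02-28 plus an interval of 36 months and 47 days, applied in the two possible orders.
val start = LocalDate.of(1997, 2, 28)
val monthsThenDays = start.plusMonths(36).plusDays(47) // 2000-02-28 -> 2000-04-15 (current behavior)
val daysThenMonths = start.plusDays(47).plusMonths(36) // 1997-04-16 -> 2000-04-16 (differs by a day)
```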

/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
* microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkFunSuite, SparkUpgradeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -358,6 +359,40 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(DateAdd, DateType, IntegerType)
}

test("date add interval") {
val d = Date.valueOf("2016-02-28")
Seq("true", "false") foreach { flag =>
withSQLConf((SQLConf.ANSI_ENABLED.key, flag)) {
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(0, 1, 0))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29")))
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 0))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
checkEvaluation(DateAddInterval(Literal(d), Literal.create(null, CalendarIntervalType)),
null)
checkEvaluation(DateAddInterval(Literal.create(null, DateType),
Literal(new CalendarInterval(1, 1, 0))),
null)
}
}

withSQLConf((SQLConf.ANSI_ENABLED.key, "true")) {
checkExceptionInExpression[IllegalArgumentException](
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
"Cannot add hours, minutes or seconds, milliseconds, microseconds to a date")
}

withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-29")))
checkEvaluation(
DateAddInterval(Literal(d), Literal(new CalendarInterval(1, 1, 25 * MICROS_PER_HOUR))),
DateTimeUtils.fromJavaDate(Date.valueOf("2016-03-30")))
}
}

test("date_sub") {
checkEvaluation(
DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1.toByte)),
@@ -176,5 +176,9 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
TimeSub('a, interval),
"`a` - INTERVAL '1 hours'"
)
checkSQL(
DateAddInterval('a, interval),
"`a` + INTERVAL '1 hours'"
)
}
}
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {

@@ -391,6 +391,14 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
assert(dateAddMonths(input, -13) === days(1996, 1, 28))
}

test("date add interval with day precision") {
val input = days(1997, 2, 28, 10, 30)
assert(dateAddInterval(input, new CalendarInterval(36, 0, 0)) === days(2000, 2, 28))
assert(dateAddInterval(input, new CalendarInterval(36, 47, 0)) === days(2000, 4, 15))
assert(dateAddInterval(input, new CalendarInterval(-13, 0, 0)) === days(1996, 1, 28))
intercept[IllegalArgumentException](dateAddInterval(input, new CalendarInterval(36, 47, 1)))
}

test("timestamp add months") {
val ts1 = date(1997, 2, 28, 10, 30, 0)
val ts2 = date(2000, 2, 28, 10, 30, 0, 123000)
@@ -0,0 +1 @@
--IMPORT datetime.sql