diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
index c1d72f9b58a4..c64aeff3c238 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala
@@ -256,7 +256,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
           val castedLit = lit.dataType match {
             case CalendarIntervalType =>
               val calendarInterval = lit.value.asInstanceOf[CalendarInterval]
-              if (calendarInterval.months > 0) {
+              if (calendarInterval.months != 0) {
                 invalid = true
                 logWarning(
                   s"Failed to extract state value watermark from condition $exprToCollectFrom " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index b9ec933f3149..d3aadad12052 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -103,7 +103,7 @@ object TimeWindow {
    */
   private def getIntervalInMicroSeconds(interval: String): Long = {
     val cal = IntervalUtils.fromString(interval)
-    if (cal.months > 0) {
+    if (cal.months != 0) {
       throw new IllegalArgumentException(
         s"Intervals greater than a month is not supported ($interval).")
     }
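
Both guards above previously tested `months > 0`, so an interval with a negative months component slipped past them even though month-based intervals are unsupported in these code paths. A quick sketch (not part of the patch, assuming a spark-shell with the catalyst classes on the classpath) of the case the old check missed:

```scala
import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString

val interval = fromString("-1 month")
assert(interval.months == -1)   // negative months component
assert(!(interval.months > 0))  // the old guard did not fire...
assert(interval.months != 0)    // ...the new one does
```
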
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 8441c2c481ec..b6bf7cd85d47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.types.MetadataBuilder
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -28,9 +29,7 @@ object EventTimeWatermark {
   val delayKey = "spark.watermarkDelayMs"
 
   def getDelayMs(delay: CalendarInterval): Long = {
-    // We define month as `31 days` to simplify calculation.
-    val millisPerMonth = TimeUnit.MICROSECONDS.toMillis(CalendarInterval.MICROS_PER_DAY) * 31
-    delay.milliseconds + delay.months * millisPerMonth
+    IntervalUtils.getDuration(delay, TimeUnit.MILLISECONDS)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index f55b0545ee9c..23b9e3f4404c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import java.util.regex.Pattern
+import java.util.concurrent.TimeUnit
 
 import scala.util.control.NonFatal
 
@@ -317,4 +317,42 @@ object IntervalUtils {
         "Interval string does not match second-nano format of ss.nnnnnnnnn")
     }
   }
+
+  /**
+   * Gets the duration of the interval.
+   *
+   * @param interval The interval to get the duration of
+   * @param targetUnit Time units of the result
+   * @param daysPerMonth The number of days per one month. The default value is 31 days
+   *                     per month. This value was taken as the default because it is used
+   *                     in Structured Streaming for watermark calculations. Having 31 days
+   *                     per month, we can guarantee that events are not dropped before
+   *                     the end of any month (February with 29 days or January with 31 days).
+   * @return Duration in the specified time units
+   */
+  def getDuration(
+      interval: CalendarInterval,
+      targetUnit: TimeUnit,
+      daysPerMonth: Int = 31): Long = {
+    val monthsDuration = Math.multiplyExact(
+      daysPerMonth * DateTimeUtils.MICROS_PER_DAY,
+      interval.months)
+    val result = Math.addExact(interval.microseconds, monthsDuration)
+    targetUnit.convert(result, TimeUnit.MICROSECONDS)
+  }
+
+  /**
+   * Checks whether the interval is negative.
+   *
+   * @param interval The checked interval
+   * @param daysPerMonth The number of days per one month. The default value is 31 days
+   *                     per month. This value was taken as the default because it is used
+   *                     in Structured Streaming for watermark calculations. Having 31 days
+   *                     per month, we can guarantee that events are not dropped before
+   *                     the end of any month (February with 29 days or January with 31 days).
+   * @return true if the duration of the given interval is less than 0, otherwise false
+   */
+  def isNegative(interval: CalendarInterval, daysPerMonth: Int = 31): Boolean = {
+    getDuration(interval, TimeUnit.MICROSECONDS, daysPerMonth) < 0
+  }
 }
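
The two new helpers centralize the month-approximation logic that was previously duplicated at call sites such as `EventTimeWatermark.getDelayMs`. A minimal sketch of the intended semantics (not part of the patch; the assertions just mirror the arithmetic above, assuming a spark-shell with the catalyst module on the classpath):

```scala
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString

// "1 month" still maps to 31 days of milliseconds, the same convention
// that getDelayMs hard-coded before this change.
val delay = fromString("1 month")
assert(IntervalUtils.getDuration(delay, TimeUnit.MILLISECONDS) == 31L * 24 * 60 * 60 * 1000)

// Mixed-sign intervals are judged by overall duration, not field by field:
// 1 month - 30 days is +1 day with 31 days per month, but negative with 29.
val mixed = fromString("1 month -30 days")
assert(IntervalUtils.getDuration(mixed, TimeUnit.DAYS) == 1)
assert(IntervalUtils.isNegative(mixed, daysPerMonth = 29))

// multiplyExact/addExact surface overflow instead of wrapping silently.
try {
  IntervalUtils.getDuration(fromString(Int.MaxValue + " month"), TimeUnit.SECONDS)
} catch {
  case e: ArithmeticException => assert(e.getMessage.contains("overflow"))
}
```

Keeping `daysPerMonth` as a parameter (defaulting to 31) preserves streaming's conservative watermark convention while letting other callers choose a different approximation.
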
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 9addc396b8d3..22944035f31d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import java.util.concurrent.TimeUnit
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.util.IntervalUtils.{fromDayTimeString, fromString, fromYearMonthString}
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -148,4 +150,38 @@ class IntervalUtilsSuite extends SparkFunSuite {
       assert(e.getMessage.contains("Cannot support (interval"))
     }
   }
+
+  test("interval duration") {
+    def duration(s: String, unit: TimeUnit, daysPerMonth: Int): Long = {
+      IntervalUtils.getDuration(fromString(s), unit, daysPerMonth)
+    }
+
+    assert(duration("0 seconds", TimeUnit.MILLISECONDS, 31) === 0)
+    assert(duration("1 month", TimeUnit.DAYS, 31) === 31)
+    assert(duration("1 microsecond", TimeUnit.MICROSECONDS, 30) === 1)
+    assert(duration("1 month -30 days", TimeUnit.DAYS, 31) === 1)
+
+    try {
+      duration(Integer.MAX_VALUE + " month", TimeUnit.SECONDS, 31)
+      fail("Expected to throw an exception for the invalid input")
+    } catch {
+      case e: ArithmeticException =>
+        assert(e.getMessage.contains("overflow"))
+    }
+  }
+
+  test("negative interval") {
+    def isNegative(s: String, daysPerMonth: Int): Boolean = {
+      IntervalUtils.isNegative(fromString(s), daysPerMonth)
+    }
+
+    assert(isNegative("-1 months", 28))
+    assert(isNegative("-1 microsecond", 30))
+    assert(isNegative("-1 month 30 days", 31))
+    assert(isNegative("2 months -61 days", 30))
+    assert(isNegative("-1 year -2 seconds", 30))
+    assert(!isNegative("0 months", 28))
+    assert(!isNegative("1 year -360 days", 31))
+    assert(!isNegative("-1 year 380 days", 31))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5f6e0a82be4c..a88fd5111221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -732,7 +732,7 @@ class Dataset[T] private[sql](
         s"Unable to parse time delay '$delayThreshold'",
         cause = Some(e))
     }
-    require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
+    require(!IntervalUtils.isNegative(parsedDelay),
       s"delay threshold ($delayThreshold) should not be negative.")
     EliminateEventTimeWatermark(
       EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index d191a79187f2..aac5da8104a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -161,12 +161,11 @@ private[sql] class GroupStateImpl[S] private(
 
   private def parseDuration(duration: String): Long = {
     val cal = IntervalUtils.fromString(duration)
-    if (cal.milliseconds < 0 || cal.months < 0) {
-      throw new IllegalArgumentException(s"Provided duration ($duration) is not positive")
+    if (IntervalUtils.isNegative(cal)) {
+      throw new IllegalArgumentException(s"Provided duration ($duration) is negative")
     }
-    val millisPerMonth = TimeUnit.MICROSECONDS.toMillis(CalendarInterval.MICROS_PER_DAY) * 31
-    cal.milliseconds + cal.months * millisPerMonth
+    IntervalUtils.getDuration(cal, TimeUnit.MILLISECONDS)
   }
 
   private def checkTimeoutTimestampAllowed(): Unit = {
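
With `parseDuration` delegating to the new helpers, state-timeout strings follow the same duration arithmetic. For instance, the new `FlatMapGroupsWithStateSuite` case below sets a timeout of `-1 month 31 days 1 second` at batch time 1000 ms and expects 2000 ms; a sketch of why (same classpath assumption as above):

```scala
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString

// -1 month and +31 days cancel out under the 31-days-per-month rule,
// leaving exactly 1 second; 1000 ms + 1000 ms batch time = 2000 ms.
val cal = fromString("-1 month 31 days 1 second")
assert(IntervalUtils.getDuration(cal, TimeUnit.MILLISECONDS) == 1000)
```
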
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index daa70a12ba0e..48113d1c18b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -31,7 +31,7 @@ private object Triggers {
 
   def convert(interval: String): Long = {
     val cal = IntervalUtils.fromString(interval)
-    if (cal.months > 0) {
+    if (cal.months != 0) {
       throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
     }
     TimeUnit.MICROSECONDS.toMillis(cal.microseconds)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index df7e9217f914..d36c64f61a72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -125,6 +125,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
     var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
       None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false)
     assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    state.setTimeoutDuration("-1 month 31 days 1 second")
+    assert(state.getTimeoutTimestamp === 2000)
     state.setTimeoutDuration(500)
     assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
@@ -225,8 +227,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
     testIllegalTimeout {
       state.setTimeoutDuration("-1 month")
     }
+
     testIllegalTimeout {
-      state.setTimeoutDuration("1 month -1 day")
+      state.setTimeoutDuration("1 month -31 day")
     }
 
     state = GroupStateImpl.createForStreaming(
@@ -241,7 +244,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
       state.setTimeoutTimestamp(10000, "-1 month")
     }
     testIllegalTimeout {
-      state.setTimeoutTimestamp(10000, "1 month -1 day")
+      state.setTimeoutTimestamp(10000, "1 month -32 day")
     }
     testIllegalTimeout {
       state.setTimeoutTimestamp(new Date(-10000))
@@ -253,7 +256,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
       state.setTimeoutTimestamp(new Date(-10000), "-1 month")
     }
     testIllegalTimeout {
-      state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day")
+      state.setTimeoutTimestamp(new Date(-10000), "1 month -32 day")
     }
   }
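
The fixture strings in `FlatMapGroupsWithStateSuite` change because `1 month -1 day` is no longer illegal: under the duration-based check it is a positive 30-day interval rather than one rejected for its negative day field. The replacements are chosen to stay illegal; a sketch of the arithmetic (not part of the patch, same classpath assumption as above):

```scala
import java.util.concurrent.TimeUnit

import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString

def days(s: String): Long = IntervalUtils.getDuration(fromString(s), TimeUnit.DAYS)

assert(days("1 month -1 day") == 30)   // now a valid, positive duration
assert(days("1 month -31 day") == 0)   // zero-length, not negative
assert(IntervalUtils.isNegative(fromString("1 month -32 day")))  // negative
```

The zero-length `1 month -31 day` still fails `setTimeoutDuration` because that API additionally requires a strictly positive duration, while the `setTimeoutTimestamp` fixtures need `1 month -32 day` to trip the negativity check itself.
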