Skip to content

Commit 863670e

Browse files
committed
address comments
1 parent 92ad086 commit 863670e

File tree

5 files changed

+72
-69
lines changed

5 files changed

+72
-69
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -159,65 +159,60 @@ public static CalendarInterval fromDayTimeString(String s, String from, String t
159159
return result;
160160
}
161161

162-
public static CalendarInterval fromSingleUnitString(String unit, String s)
162+
public static CalendarInterval fromUnitString(String[] units, String[] values)
163163
throws IllegalArgumentException {
164+
assert units.length == values.length;
165+
int months = 0;
166+
long microseconds = 0;
164167

165-
CalendarInterval result = null;
166-
if (s == null) {
167-
throw new IllegalArgumentException(String.format("Interval %s string was null", unit));
168-
}
169-
s = s.trim();
170-
try {
171-
switch (unit) {
172-
case "year":
173-
int year = (int) toLongWithRange("year", s,
174-
Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12);
175-
result = new CalendarInterval(year * 12, 0L);
176-
break;
177-
case "month":
178-
int month = (int) toLongWithRange("month", s,
179-
Integer.MIN_VALUE, Integer.MAX_VALUE);
180-
result = new CalendarInterval(month, 0L);
181-
break;
182-
case "week":
183-
long week = toLongWithRange("week", s,
184-
Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
185-
result = new CalendarInterval(0, week * MICROS_PER_WEEK);
186-
break;
187-
case "day":
188-
long day = toLongWithRange("day", s,
189-
Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
190-
result = new CalendarInterval(0, day * MICROS_PER_DAY);
191-
break;
192-
case "hour":
193-
long hour = toLongWithRange("hour", s,
194-
Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR);
195-
result = new CalendarInterval(0, hour * MICROS_PER_HOUR);
196-
break;
197-
case "minute":
198-
long minute = toLongWithRange("minute", s,
199-
Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE);
200-
result = new CalendarInterval(0, minute * MICROS_PER_MINUTE);
201-
break;
202-
case "second": {
203-
long micros = parseSecondNano(s);
204-
result = new CalendarInterval(0, micros);
205-
break;
168+
for (int i = 0; i < units.length; i++) {
169+
try {
170+
String value = values[i].trim();
171+
switch (units[i]) {
172+
case "year":
173+
months = Math.addExact(months, Math.multiplyExact(Integer.parseInt(value), 12));
174+
break;
175+
case "month":
176+
months = Math.addExact(months, Integer.parseInt(value));
177+
break;
178+
case "week":
179+
microseconds = Math.addExact(
180+
microseconds,
181+
Math.multiplyExact(Long.parseLong(value), MICROS_PER_WEEK));
182+
break;
183+
case "day":
184+
microseconds = Math.addExact(
185+
microseconds,
186+
Math.multiplyExact(Long.parseLong(value), MICROS_PER_DAY));
187+
break;
188+
case "hour":
189+
microseconds = Math.addExact(
190+
microseconds,
191+
Math.multiplyExact(Long.parseLong(value), MICROS_PER_HOUR));
192+
break;
193+
case "minute":
194+
microseconds = Math.addExact(
195+
microseconds,
196+
Math.multiplyExact(Long.parseLong(value), MICROS_PER_MINUTE));
197+
break;
198+
case "second": {
199+
microseconds = Math.addExact(microseconds, parseSecondNano(value));
200+
break;
201+
}
202+
case "millisecond":
203+
microseconds = Math.addExact(
204+
microseconds,
205+
Math.multiplyExact(Long.parseLong(value), MICROS_PER_MILLI));
206+
break;
207+
case "microsecond":
208+
microseconds = Math.addExact(microseconds, Long.parseLong(value));
209+
break;
206210
}
207-
case "millisecond":
208-
long millisecond = toLongWithRange("millisecond", s,
209-
Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
210-
result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
211-
break;
212-
case "microsecond":
213-
long micros = Long.parseLong(s);
214-
result = new CalendarInterval(0, micros);
215-
break;
211+
} catch (Exception e) {
212+
throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
216213
}
217-
} catch (Exception e) {
218-
throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
219214
}
220-
return result;
215+
return new CalendarInterval(months, microseconds);
221216
}
222217

223218
/**

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ singleTableSchema
8080
;
8181

8282
singleInterval
83-
: INTERVAL? intervalField+ EOF
83+
: INTERVAL? (intervalValue intervalUnit)+ EOF
8484
;
8585

8686
statement

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
103103

104104
override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
105105
withOrigin(ctx) {
106-
val intervals = ctx.intervalField.asScala.map(visitIntervalField)
107-
validate(intervals.nonEmpty,
108-
"at least one time unit should be given for interval literal", ctx)
109-
intervals.reduce(_.add(_))
106+
val units = ctx.intervalUnit().asScala.map {
107+
u => normalizeInternalUnit(u.getText.toLowerCase(Locale.ROOT))
108+
}.toArray
109+
val values = ctx.intervalValue().asScala.map(getIntervalValue).toArray
110+
CalendarInterval.fromUnitString(units, values)
110111
}
111112
}
112113

@@ -1940,18 +1941,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
19401941
*/
19411942
override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
19421943
import ctx._
1943-
val s = if (value.STRING() != null) {
1944-
string(value.STRING())
1945-
} else {
1946-
value.getText
1947-
}
1944+
val s = getIntervalValue(value)
19481945
try {
19491946
val unitText = unit.getText.toLowerCase(Locale.ROOT)
19501947
val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match {
19511948
case (u, None) =>
1952-
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
1953-
val unit = if (u.endsWith("s")) u.substring(0, u.length - 1) else u
1954-
CalendarInterval.fromSingleUnitString(unit, s)
1949+
CalendarInterval.fromUnitString(Array(normalizeInternalUnit(u)), Array(s))
19551950
case ("year", Some("month")) =>
19561951
CalendarInterval.fromYearMonthString(s)
19571952
case ("day", Some("hour")) =>
@@ -1980,6 +1975,19 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
19801975
}
19811976
}
19821977

1978+
private def getIntervalValue(value: IntervalValueContext): String = {
1979+
if (value.STRING() != null) {
1980+
string(value.STRING())
1981+
} else {
1982+
value.getText
1983+
}
1984+
}
1985+
1986+
// Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
1987+
private def normalizeInternalUnit(s: String): String = {
1988+
if (s.endsWith("s")) s.substring(0, s.length - 1) else s
1989+
}
1990+
19831991
/* ********************************************************************************************
19841992
* DataType parsing
19851993
* ******************************************************************************************** */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ trait ParserInterface {
7575
/**
7676
* Parse a string to a [[CalendarInterval]].
7777
*/
78-
@throws[ParseException]("Text cannot be parsed to a DataType")
78+
@throws[ParseException]("Text cannot be parsed to an interval")
7979
def parseInterval(sqlText: String): CalendarInterval
8080
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ class ExpressionParserSuite extends AnalysisTest {
597597
"microsecond")
598598

599599
def intervalLiteral(u: String, s: String): Literal = {
600-
Literal(CalendarInterval.fromSingleUnitString(u, s))
600+
Literal(CalendarInterval.fromUnitString(Array(u), Array(s)))
601601
}
602602

603603
test("intervals") {

0 commit comments

Comments
 (0)