Skip to content

Commit a476c5a

Browse files
committed
address comments from davies
1 parent d44ea5f commit a476c5a

File tree

2 files changed

+89
-30
lines changed

2 files changed

+89
-30
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
277277
*/
278278
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
279279

280+
// Spark's implicit casting accepts strings in both date and timestamp format, as
281+
// well as TimestampType.
280282
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
281283

282284
override def dataType: DataType = DateType
@@ -361,37 +363,76 @@ case class Trunc(date: Expression, format: Expression)
361363
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
362364
override def dataType: DataType = DateType
363365

364-
override def nullSafeEval(d: Any, fmt: Any): Any = {
365-
val minItem = DateTimeUtils.getFmt(fmt.asInstanceOf[UTF8String])
366-
if (minItem == -1) {
367-
// unknown format
368-
null
366+
lazy val constFmt = format.eval().asInstanceOf[UTF8String]
367+
368+
override def eval(input: InternalRow): Any = {
369+
if (format.foldable) {
370+
val minItem = DateTimeUtils.getFmt(constFmt)
371+
if (minItem == -1) {
372+
// unknown format
373+
null
374+
} else {
375+
val d = date.eval(input)
376+
if (d == null) {
377+
null
378+
} else {
379+
DateTimeUtils.dateTrunc(d.asInstanceOf[Int], minItem)
380+
}
381+
}
369382
} else {
370-
val days = d.asInstanceOf[Int]
371-
if (minItem == Calendar.YEAR) {
372-
days - DateTimeUtils.getDayInYear(days) + 1
383+
val fmt = format.eval(input).asInstanceOf[UTF8String]
384+
val d = date.eval(input)
385+
if (d == null) {
386+
null
373387
} else {
374-
// trunc to MONTH
375-
days - DateTimeUtils.getDayOfMonth(days) + 1
388+
val minItem = DateTimeUtils.getFmt(fmt)
389+
if (minItem == -1) {
390+
// unknown format
391+
null
392+
} else {
393+
DateTimeUtils.dateTrunc(d.asInstanceOf[Int], minItem)
394+
}
376395
}
377396
}
378397
}
379398

380399
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
381-
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
382-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
383-
val form = ctx.freshName("form")
384-
s"""
385-
int $form = $dtu.getFmt($fmt);
386-
if ($form == ${Calendar.YEAR}) {
387-
${ev.primitive} = $dateVal - $dtu.getDayInYear($dateVal) + 1;
388-
} else if ($form == ${Calendar.MONTH}) {
389-
${ev.primitive} = $dateVal - $dtu.getDayInYear($dateVal) + 1;
390-
} else {
391-
${ev.isNull} = true;
392-
}
393-
"""
394-
})
400+
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
401+
if (date.foldable) {
402+
val d = date.gen(ctx)
403+
val minItem = DateTimeUtils.getFmt(constFmt)
404+
if (d == null || minItem == -1) {
405+
s"""
406+
boolean ${ev.isNull} = true;
407+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
408+
"""
409+
} else {
410+
s"""
411+
${d.code}
412+
boolean ${ev.isNull} = ${d.isNull};
413+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
414+
if (!${ev.isNull}) {
415+
if ($minItem == -1) {
416+
${ev.isNull} = true;
417+
} else {
418+
${ev.primitive} = $dtu.dateTrunc(${d.primitive}, $minItem);
419+
}
420+
}
421+
"""
422+
}
423+
} else {
424+
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
425+
val form = ctx.freshName("form")
426+
s"""
427+
int $form = $dtu.getFmt($fmt);
428+
if ($form == -1) {
429+
${ev.isNull} = true;
430+
} else {
431+
${ev.primitive} = $dtu.dateTrunc($dateVal, $form);
432+
}
433+
"""
434+
})
435+
}
395436
}
396437

397438
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -642,15 +642,33 @@ object DateTimeUtils {
642642
}
643643

644644
/**
645-
* Returns the truncate level, could be [[Calendar.MONTH]]/[[Calendar.YEAR]]/-1
645+
* Returns the trunc date from original date and trunc level.
646+
* Trunc level should be generated using `this.getFmt()`.
647+
*/
648+
def dateTrunc(d: Int, minItem: Int): Int = {
649+
if (minItem == 2) {
650+
// trunc to year
651+
d - DateTimeUtils.getDayInYear(d) + 1
652+
} else {
653+
// trunc to MONTH
654+
d - DateTimeUtils.getDayOfMonth(d) + 1
655+
}
656+
}
657+
658+
/**
659+
* Returns the truncate level, could be 1 for month, 2 for year, -1 for invalid/null
646660
* -1 means unsupported truncate level.
647661
*/
648662
def getFmt(string: UTF8String): Int = {
649-
val fmtString = string.toString.toUpperCase
650-
fmtString match {
651-
case "MON" | "MONTH" | "MM" => Calendar.MONTH
652-
case "YEAR"| "YYYY" | "YY" => Calendar.YEAR
653-
case _ => -1
663+
if (string == null) {
664+
-1
665+
} else {
666+
val fmtString = string.toString.toUpperCase
667+
fmtString match {
668+
case "MON" | "MONTH" | "MM" => 1
669+
case "YEAR" | "YYYY" | "YY" => 2
670+
case _ => -1
671+
}
654672
}
655673
}
656674
}

0 commit comments

Comments
 (0)