Skip to content

Commit f6f070a

Browse files
committed
address comments from davies
1 parent 6a4cbb3 commit f6f070a

File tree

4 files changed

+168
-36
lines changed

4 files changed

+168
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 139 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,12 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
276276
* If the first parameter is a Date or Timestamp instead of String, we will ignore the
277277
* second parameter.
278278
*/
279-
case class UnixTimestamp(left: Expression, right: Expression)
279+
case class UnixTimestamp(timeExp: Expression, format: Expression)
280280
extends BinaryExpression with ExpectsInputTypes {
281281

282+
override def left: Expression = timeExp
283+
override def right: Expression = format
284+
282285
def this(time: Expression) = {
283286
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
284287
}
@@ -292,29 +295,66 @@ case class UnixTimestamp(left: Expression, right: Expression)
292295

293296
override def dataType: DataType = LongType
294297

295-
lazy val constFormat: String = right.eval().asInstanceOf[UTF8String].toString
296-
override def nullSafeEval(time: Any, format: Any): Any = {
297-
left.dataType match {
298-
case DateType =>
299-
DateTimeUtils.daysToMillis(time.asInstanceOf[Int]) / 1000L
300-
case TimestampType =>
301-
time.asInstanceOf[Long] / 1000000L
302-
case StringType if right.foldable =>
303-
if (constFormat != null) {
304-
val sdf = new SimpleDateFormat(constFormat)
305-
Try(sdf.parse(time.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
306-
} else {
307-
null
308-
}
309-
case StringType =>
310-
val formatString = format.asInstanceOf[UTF8String].toString
311-
val sdf = new SimpleDateFormat(formatString)
312-
Try(sdf.parse(time.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
298+
lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
299+
300+
override def eval(input: InternalRow): Any = {
301+
val t = left.eval(input)
302+
if (t == null) {
303+
null
304+
} else {
305+
left.dataType match {
306+
case DateType =>
307+
DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L
308+
case TimestampType =>
309+
t.asInstanceOf[Long] / 1000000L
310+
case StringType if right.foldable =>
311+
if (constFormat != null) {
312+
Try(new SimpleDateFormat(constFormat.toString).parse(
313+
t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
314+
} else {
315+
null
316+
}
317+
case StringType =>
318+
val f = format.eval(input)
319+
if (f == null) {
320+
null
321+
} else {
322+
val formatString = f.asInstanceOf[UTF8String].toString
323+
Try(new SimpleDateFormat(formatString).parse(
324+
t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
325+
}
326+
}
313327
}
314328
}
315329

316330
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
317331
left.dataType match {
332+
case StringType if right.foldable =>
333+
val sdf = classOf[SimpleDateFormat].getName
334+
val fString = if (constFormat == null) null else constFormat.toString
335+
val formatter = ctx.freshName("formatter")
336+
if (fString == null) {
337+
s"""
338+
boolean ${ev.isNull} = true;
339+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
340+
"""
341+
} else {
342+
val eval1 = left.gen(ctx)
343+
s"""
344+
${eval1.code}
345+
boolean ${ev.isNull} = ${eval1.isNull};
346+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
347+
if (!${ev.isNull}) {
348+
try {
349+
$sdf $formatter = new $sdf("$fString");
350+
${ev.primitive} =
351+
$formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L;
352+
} catch (java.lang.Throwable e) {
353+
${ev.isNull} = true;
354+
}
355+
}
356+
"""
357+
}
318358
case StringType =>
319359
val sdf = classOf[SimpleDateFormat].getName
320360
nullSafeCodeGen(ctx, ev, (string, format) => {
@@ -328,14 +368,26 @@ case class UnixTimestamp(left: Expression, right: Expression)
328368
"""
329369
})
330370
case TimestampType =>
331-
defineCodeGen(ctx, ev, (timestamp, format) => {
332-
s"""$timestamp / 1000000L"""
333-
})
371+
val eval1 = left.gen(ctx)
372+
s"""
373+
${eval1.code}
374+
boolean ${ev.isNull} = ${eval1.isNull};
375+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
376+
if (!${ev.isNull}) {
377+
${ev.primitive} = ${eval1.primitive} / 1000000L;
378+
}
379+
"""
334380
case DateType =>
335381
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
336-
defineCodeGen(ctx, ev, (date, format) => {
337-
s"""$dtu.daysToMillis($date) / 1000L"""
338-
})
382+
val eval1 = left.gen(ctx)
383+
s"""
384+
${eval1.code}
385+
boolean ${ev.isNull} = ${eval1.isNull};
386+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
387+
if (!${ev.isNull}) {
388+
${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L;
389+
}
390+
"""
339391
}
340392
}
341393
}
@@ -345,9 +397,12 @@ case class UnixTimestamp(left: Expression, right: Expression)
345397
* representing the timestamp of that moment in the current system time zone in the given
346398
* format. If the format is missing, using format like "1970-01-01 00:00:00".
347399
*/
348-
case class FromUnixTime(left: Expression, right: Expression)
400+
case class FromUnixTime(sec: Expression, format: Expression)
349401
extends BinaryExpression with ImplicitCastInputTypes {
350402

403+
override def left: Expression = sec
404+
override def right: Expression = format
405+
351406
def this(unix: Expression) = {
352407
this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
353408
}
@@ -356,17 +411,68 @@ case class FromUnixTime(left: Expression, right: Expression)
356411

357412
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
358413

359-
override protected def nullSafeEval(time: Any, format: Any): Any = {
360-
val sdf = new SimpleDateFormat(format.toString)
361-
UTF8String.fromString(sdf.format(new java.util.Date(time.asInstanceOf[Long] * 1000L)))
414+
lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
415+
416+
override def eval(input: InternalRow): Any = {
417+
val time = left.eval(input)
418+
if (time == null) {
419+
null
420+
} else {
421+
if (format.foldable) {
422+
if (constFormat == null) {
423+
null
424+
} else {
425+
Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format(
426+
new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
427+
}
428+
} else {
429+
val f = format.eval(input)
430+
if (f == null) {
431+
null
432+
} else {
433+
Try(UTF8String.fromString(new SimpleDateFormat(
434+
f.asInstanceOf[UTF8String].toString).format(new java.util.Date(
435+
time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
436+
}
437+
}
438+
}
362439
}
363440

364441
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
365442
val sdf = classOf[SimpleDateFormat].getName
366-
defineCodeGen(ctx, ev, (seconds, format) => {
367-
s"""UTF8String.fromString((new $sdf($format.toString())).format(
368-
new java.sql.Timestamp($seconds * 1000L)))""".stripMargin
369-
})
443+
if (format.foldable) {
444+
if (constFormat == null) {
445+
s"""
446+
boolean ${ev.isNull} = true;
447+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
448+
"""
449+
} else {
450+
val t = left.gen(ctx)
451+
s"""
452+
${t.code}
453+
boolean ${ev.isNull} = ${t.isNull};
454+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
455+
if (!${ev.isNull}) {
456+
try {
457+
${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format(
458+
new java.sql.Timestamp(${t.primitive} * 1000L)));
459+
} catch (java.lang.Throwable e) {
460+
${ev.isNull} = true;
461+
}
462+
}
463+
"""
464+
}
465+
} else {
466+
nullSafeCodeGen(ctx, ev, (seconds, f) => {
467+
s"""
468+
try {
469+
${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format(
470+
new java.sql.Timestamp($seconds * 1000L)));
471+
} catch (java.lang.Throwable e) {
472+
${ev.isNull} = true;
473+
}""".stripMargin
474+
})
475+
}
370476
}
371477

372478
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
291291
checkEvaluation(
292292
FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null)
293293
checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null)
294+
checkEvaluation(
295+
FromUnixTime(Literal(0L), Literal("not a valid format")), null)
294296
}
295297

296298
test("unix_timestamp") {
@@ -323,7 +325,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
323325
UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null)
324326
checkEvaluation(
325327
UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null)
326-
checkEvaluation(UnixTimestamp(Literal(date1), Literal.create(null, StringType)), null)
328+
checkEvaluation(UnixTimestamp(
329+
Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L)
330+
checkEvaluation(
331+
UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
327332
}
328333

329334
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2092,14 +2092,18 @@ object functions {
20922092
def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
20932093

20942094
/**
2095-
* Gets current Unix timestamp in seconds.
2095+
* Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
2096+
* representing the timestamp of that moment in the current system time zone in the given
2097+
* format.
20962098
* @group datetime_funcs
20972099
* @since 1.5.0
20982100
*/
20992101
def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss"))
21002102

21012103
/**
2102-
* Gets current Unix timestamp in seconds.
2104+
* Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
2105+
* representing the timestamp of that moment in the current system time zone in the given
2106+
* format.
21032107
* @group datetime_funcs
21042108
* @since 1.5.0
21052109
*/

sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,15 @@ class DateFunctionsSuite extends QueryTest {
222222
checkAnswer(
223223
df.select(from_unixtime(col("a"), fmt3)),
224224
Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000)))))
225+
checkAnswer(
226+
df.selectExpr("from_unixtime(a)"),
227+
Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000)))))
228+
checkAnswer(
229+
df.selectExpr(s"from_unixtime(a, '$fmt2')"),
230+
Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000)))))
231+
checkAnswer(
232+
df.selectExpr(s"from_unixtime(a, '$fmt3')"),
233+
Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000)))))
225234
}
226235

227236
test("unix_timestamp") {
@@ -243,6 +252,14 @@ class DateFunctionsSuite extends QueryTest {
243252
Row(date1.getTime / 1000L), Row(date2.getTime / 1000L)))
244253
checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq(
245254
Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
255+
checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq(
256+
Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
257+
checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq(
258+
Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
259+
checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq(
260+
Row(date1.getTime / 1000L), Row(date2.getTime / 1000L)))
261+
checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq(
262+
Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L)))
246263
}
247264

248265
}

0 commit comments

Comments
 (0)