Skip to content

Commit d09acc1

Browse files
committed
implement getLastDayOfMonth to avoid repeated evaluation
1 parent d857ec3 commit d09acc1

File tree

2 files changed

+52
-42
lines changed

2 files changed

+52
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu
7474

7575
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
7676
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
77-
defineCodeGen(ctx, ev, (c) =>
78-
s"""$dtu.getHours($c)"""
79-
)
77+
defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)")
8078
}
8179
}
8280

@@ -92,9 +90,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn
9290

9391
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
9492
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
95-
defineCodeGen(ctx, ev, (c) =>
96-
s"""$dtu.getMinutes($c)"""
97-
)
93+
defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)")
9894
}
9995
}
10096

@@ -110,9 +106,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn
110106

111107
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
112108
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
113-
defineCodeGen(ctx, ev, (c) =>
114-
s"""$dtu.getSeconds($c)"""
115-
)
109+
defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)")
116110
}
117111
}
118112

@@ -128,9 +122,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
128122

129123
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
130124
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
131-
defineCodeGen(ctx, ev, (c) =>
132-
s"""$dtu.getDayInYear($c)"""
133-
)
125+
defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)")
134126
}
135127
}
136128

@@ -147,9 +139,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
147139

148140
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
149141
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
150-
defineCodeGen(ctx, ev, c =>
151-
s"""$dtu.getYear($c)"""
152-
)
142+
defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)")
153143
}
154144
}
155145

@@ -165,9 +155,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
165155

166156
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
167157
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
168-
defineCodeGen(ctx, ev, (c) =>
169-
s"""$dtu.getQuarter($c)"""
170-
)
158+
defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)")
171159
}
172160
}
173161

@@ -183,9 +171,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
183171

184172
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
185173
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
186-
defineCodeGen(ctx, ev, (c) =>
187-
s"""$dtu.getMonth($c)"""
188-
)
174+
defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)")
189175
}
190176
}
191177

@@ -201,9 +187,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
201187

202188
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
203189
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
204-
defineCodeGen(ctx, ev, (c) =>
205-
s"""$dtu.getDayOfMonth($c)"""
206-
)
190+
defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)")
207191
}
208192
}
209193

@@ -226,7 +210,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
226210
}
227211

228212
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
229-
nullSafeCodeGen(ctx, ev, (time) => {
213+
nullSafeCodeGen(ctx, ev, time => {
230214
val cal = classOf[Calendar].getName
231215
val c = ctx.freshName("cal")
232216
ctx.addMutableState(cal, c,
@@ -250,8 +234,6 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
250234

251235
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
252236

253-
override def prettyName: String = "date_format"
254-
255237
override protected def nullSafeEval(timestamp: Any, format: Any): Any = {
256238
val sdf = new SimpleDateFormat(format.toString)
257239
UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000)))
@@ -264,6 +246,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
264246
.format(new java.sql.Date($timestamp / 1000)))"""
265247
})
266248
}
249+
250+
override def prettyName: String = "date_format"
267251
}
268252

269253
/**
@@ -277,15 +261,12 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
277261
override def dataType: DataType = DateType
278262

279263
override def nullSafeEval(date: Any): Any = {
280-
val days = date.asInstanceOf[Int]
281-
DateTimeUtils.getLastDayOfMonth(days)
264+
DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int])
282265
}
283266

284267
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
285268
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
286-
defineCodeGen(ctx, ev, (sd) => {
287-
s"$dtu.getLastDayOfMonth($sd)"
288-
})
269+
defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)")
289270
}
290271

291272
override def prettyName: String = "last_day"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -601,22 +601,51 @@ object DateTimeUtils {
601601
}
602602

603603
/**
604-
* number of days in a non-leap year.
604+
* Returns the number of days till the month end.
605+
* If `date` itself is the last day of a month, returns 0.
605606
*/
606-
private[this] val daysInNormalYear = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)
607+
private[this] def daysToMonthEnd(date: Int): Int = {
  val (year, ordinal) = getYearAndDayInYear(date)
  val leap = isLeapYear(year)
  if (leap && ordinal > 31 && ordinal <= 60) {
    // In a leap year, days 32..60 are treated as February, so day 60 is the month end.
    60 - ordinal
  } else {
    // After February of a leap year, shift back by one day so that the
    // non-leap cumulative month-end table below can be reused unchanged.
    val normalized = if (leap && ordinal > 60) ordinal - 1 else ordinal
    // Day-of-year on which the month containing `normalized` ends (non-leap numbering).
    val monthEnd = normalized match {
      case d if d <= 31 => 31
      case d if d <= 59 => 59
      case d if d <= 90 => 90
      case d if d <= 120 => 120
      case d if d <= 151 => 151
      case d if d <= 181 => 181
      case d if d <= 212 => 212
      case d if d <= 243 => 243
      case d if d <= 273 => 273
      case d if d <= 304 => 304
      case d if d <= 334 => 334
      case _ => 365
    }
    monthEnd - normalized
  }
}
607643

608644
/**
609645
* Returns last day of the month for the given date. The date is expressed in days
610646
* since 1.1.1970.
611647
*/
612648
def getLastDayOfMonth(date: Int): Int = {
613-
val dayOfMonth = getDayOfMonth(date)
614-
val month = getMonth(date)
615-
if (month == 2 && isLeapYear(getYear(date))) {
616-
date + daysInNormalYear(month - 1) + 1 - dayOfMonth
617-
} else {
618-
date + daysInNormalYear(month - 1) - dayOfMonth
619-
}
649+
date + daysToMonthEnd(date)
620650
}
621-
622651
}

0 commit comments

Comments
 (0)