Skip to content

Commit ec87c69

Browse files
committed
[SPARK-8119] bug fixing and refactoring
1 parent 1358cdc commit ec87c69

File tree

1 file changed

+29
-36
lines changed

1 file changed

+29
-36
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -61,48 +61,46 @@ case class CurrentTimestamp() extends LeafExpression {
6161
}
6262
}
6363

64-
/**
65-
* Abstract class for create time format expressions.
66-
*/
67-
abstract class TimeFormatExpression extends UnaryExpression with ExpectsInputTypes {
68-
self: Product =>
69-
70-
protected val factorToMilli: Int
71-
72-
protected val cntPerInterval: Int
64+
case class Hour(child: Expression) extends UnaryExpression with ExpectsInputTypes {
7365

7466
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
7567

7668
override def dataType: DataType = IntegerType
7769

7870
override protected def nullSafeEval(timestamp: Any): Any = {
7971
val time = timestamp.asInstanceOf[Long] / 1000
80-
val longTime: Long = time + TimeZone.getDefault.getOffset(time)
81-
((longTime / factorToMilli) % cntPerInterval).toInt
72+
val longTime: Long = time.asInstanceOf[Long] + TimeZone.getDefault.getOffset(time)
73+
((longTime / (1000 * 3600)) % 24).toInt
8274
}
8375

8476
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
8577
val tz = classOf[TimeZone].getName
8678
defineCodeGen(ctx, ev, (c) =>
87-
s"""(${ctx.javaType(dataType)})
88-
((($c / 1000) + $tz.getDefault().getOffset($c / 1000))
89-
/ $factorToMilli % $cntPerInterval)"""
79+
s"""(int) ((($c / 1000) + $tz.getDefault().getOffset($c / 1000))
80+
/ (1000 * 3600) % 24)""".stripMargin
9081
)
9182
}
9283
}
9384

94-
case class Hour(child: Expression) extends TimeFormatExpression {
95-
96-
override protected val factorToMilli: Int = 1000 * 3600
85+
case class Minute(child: Expression) extends UnaryExpression with ExpectsInputTypes {
9786

98-
override protected val cntPerInterval: Int = 24
99-
}
87+
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
10088

101-
case class Minute(child: Expression) extends TimeFormatExpression {
89+
override def dataType: DataType = IntegerType
10290

103-
override protected val factorToMilli: Int = 1000 * 60
91+
override protected def nullSafeEval(timestamp: Any): Any = {
92+
val time = timestamp.asInstanceOf[Long] / 1000
93+
val longTime: Long = time.asInstanceOf[Long] + TimeZone.getDefault.getOffset(time)
94+
((longTime / (1000 * 60)) % 60).toInt
95+
}
10496

105-
override protected val cntPerInterval: Int = 60
97+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
98+
val tz = classOf[TimeZone].getName
99+
defineCodeGen(ctx, ev, (c) =>
100+
s"""(int) ((($c / 1000) + $tz.getDefault().getOffset($c / 1000))
101+
/ (1000 * 60) % 60)""".stripMargin
102+
)
103+
}
106104
}
107105

108106
case class Second(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -122,15 +120,6 @@ case class Second(child: Expression) extends UnaryExpression with ExpectsInputTy
122120
}
123121
}
124122

125-
private[sql] object DateFormatExpression {
126-
127-
def isLeapYear(year: Int): Boolean = {
128-
(year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0)
129-
}
130-
131-
132-
}
133-
134123
abstract class DateFormatExpression extends UnaryExpression with ExpectsInputTypes {
135124
self: Product =>
136125

@@ -140,6 +129,10 @@ abstract class DateFormatExpression extends UnaryExpression with ExpectsInputTyp
140129
// this is year -17999, calculation: 50 * daysIn400Year
141130
val toYearZero = to2001 + 7304850
142131

132+
protected def isLeapYear(year: Int): Boolean = {
133+
(year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0)
134+
}
135+
143136
private[this] def yearBoundary(year: Int): Int = {
144137
year * 365 + ((year / 4 ) - (year / 100) + (year / 400))
145138
}
@@ -178,7 +171,7 @@ abstract class DateFormatExpression extends UnaryExpression with ExpectsInputTyp
178171
s"""
179172
int $daysIn400Years = 146097;
180173
int $to2001 = -11323;
181-
int $toYearZero = to2001 + 7304850;
174+
int $toYearZero = $to2001 + 7304850;
182175

183176
int $daysNormalized = $input + $toYearZero;
184177
int $numOfQuarterCenturies = $daysNormalized / $daysIn400Years;
@@ -188,7 +181,7 @@ abstract class DateFormatExpression extends UnaryExpression with ExpectsInputTyp
188181
$years = ($daysInThis400 > $years * 365 + (($years / 4 ) - ($years / 100) +
189182
($years / 400))) ? $years : $years - 1;
190183

191-
int $year = (2001 - 20000) + 400 * $numOfQuarterCenturies + years;
184+
int $year = (2001 - 20000) + 400 * $numOfQuarterCenturies + $years;
192185
int $dayInYear = $daysInThis400 -
193186
($years * 365 + (($years / 4 ) - ($years / 100) + ($years / 400)));
194187
${f(year, dayInYear)};
@@ -231,7 +224,7 @@ case class Quarter(child: Expression) extends DateFormatExpression {
231224

232225
override protected def nullSafeEval(input: Any): Any = {
233226
val (year, dayInYear) = calculateYearAndDayInYear(input.asInstanceOf[Int])
234-
val leap = if (DateFormatExpression.isLeapYear(year)) 1 else 0
227+
val leap = if (isLeapYear(year)) 1 else 0
235228
dayInYear match {
236229
case i: Int if i <= 90 + leap => 1
237230
case i: Int if i <= 181 + leap => 2
@@ -263,7 +256,7 @@ case class Month(child: Expression) extends DateFormatExpression {
263256

264257
override protected def nullSafeEval(input: Any): Any = {
265258
val (year, dayInYear) = calculateYearAndDayInYear(input.asInstanceOf[Int])
266-
val leap = if (DateFormatExpression.isLeapYear(year)) 1 else 0
259+
val leap = if (isLeapYear(year)) 1 else 0
267260
dayInYear match {
268261
case i: Int if i <= 31 => 1
269262
case i: Int if i <= 59 + leap => 2
@@ -325,7 +318,7 @@ case class Day(child: Expression) extends DateFormatExpression with ExpectsInput
325318

326319
override protected def nullSafeEval(input: Any): Any = {
327320
val (year, dayInYear) = calculateYearAndDayInYear(input.asInstanceOf[Int])
328-
val leap = if (DateFormatExpression.isLeapYear(year)) 1 else 0
321+
val leap = if (isLeapYear(year)) 1 else 0
329322
dayInYear match {
330323
case i: Int if i <= 31 => i
331324
case i: Int if i <= 59 + leap => i - 31

0 commit comments

Comments
 (0)