Skip to content

Commit 356df78

Browse files
committed
Rely on Spark's cast mechanism; simplified implementation.
1 parent 02efc5d commit 356df78

File tree

3 files changed

+65
-160
lines changed

3 files changed

+65
-160
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetime.scala

Lines changed: 41 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -26,115 +26,48 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
2626
import org.apache.spark.sql.types._
2727
import org.apache.spark.unsafe.types.UTF8String
2828

29-
abstract class DateFormatExpression extends Expression { self: Product =>
29+
abstract class DateFormatExpression extends UnaryExpression with ExpectsInputTypes {
30+
self: Product =>
3031

3132
protected val format: String
3233

33-
protected val caller: String
34-
35-
protected val date: Expression
36-
37-
override def foldable: Boolean = date.foldable
38-
39-
override def nullable: Boolean = true
40-
41-
override def children: Seq[Expression] = Seq(date)
34+
override def expectedChildTypes: Seq[DataType] = Seq(TimestampType)
4235

4336
override def eval(input: InternalRow): Any = {
44-
val valueLeft = date.eval(input)
37+
val valueLeft = child.eval(input)
4538
if (valueLeft == null) {
4639
null
4740
} else {
4841
if (format == null) {
4942
null
5043
} else {
5144
val sdf = new SimpleDateFormat(format)
52-
date.dataType match {
53-
case TimestampType =>
54-
UTF8String.fromString(sdf.format(new Date(valueLeft.asInstanceOf[Long] / 10000)))
55-
case DateType =>
56-
UTF8String.fromString(sdf.format(DateTimeUtils.toJavaDate(valueLeft.asInstanceOf[Int])))
57-
case StringType =>
58-
UTF8String.fromString(
59-
sdf.format(DateTimeUtils.stringToTime(valueLeft.toString)))
60-
}
45+
UTF8String.fromString(sdf.format(new Date(valueLeft.asInstanceOf[Long] / 10000)))
6146
}
6247
}
6348
}
6449

65-
override def checkInputDataTypes(): TypeCheckResult =
66-
date.dataType match {
67-
case null => TypeCheckResult.TypeCheckSuccess
68-
case _: DateType => TypeCheckResult.TypeCheckSuccess
69-
case _: TimestampType => TypeCheckResult.TypeCheckSuccess
70-
case _: StringType => TypeCheckResult.TypeCheckSuccess
71-
case _ =>
72-
TypeCheckResult.TypeCheckFailure(s"$caller accepts date types as argument, " +
73-
s" not ${date.dataType}")
74-
}
75-
76-
77-
/**
78-
* Called by date format expressions to generate a code block that returns the result
79-
*
80-
* As an example, the following parse the result to int
81-
* {{{
82-
* defineCodeGen(ctx, ev, c => s"Integer.parseInt($c.toString())")
83-
* }}}
84-
*
85-
* @param f function that accepts a variable name and returns Java code to parse an
86-
* [[UTF8String]] to the expected output type
87-
*/
88-
89-
protected def defineCodeGen(
50+
override protected def defineCodeGen(
9051
ctx: CodeGenContext,
9152
ev: GeneratedExpressionCode,
9253
f: String => String): String = {
9354

9455
val sdf = classOf[SimpleDateFormat].getName
95-
val dtUtils = "org.apache.spark.sql.catalyst.util.DateTimeUtils"
96-
97-
val eval1 = date.gen(ctx)
98-
99-
val parseInput = date.dataType match {
100-
case StringType => s"new java.sql.Date($dtUtils.stringToTime(${eval1.primitive}.toString()).getTime())"
101-
case TimestampType => s"new java.sql.Date(${eval1.primitive} / 10000)"
102-
case DateType => s"$dtUtils.toJavaDate(${eval1.primitive})"
103-
}
104-
105-
s"""
106-
${eval1.code}
107-
boolean ${ev.isNull} = ${eval1.isNull};
108-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
109-
if (!${ev.isNull}) {
110-
$sdf sdf = new $sdf("$format");
111-
${ctx.stringType} s = ${ctx.stringType}.fromString(sdf.format($parseInput));
112-
${ev.primitive} = ${f("s")};
113-
} else {
114-
${ev.isNull} = true;
115-
}
116-
"""
56+
super.defineCodeGen(ctx, ev, (x) => {
57+
f(s"""${ctx.stringType}.fromString((new $sdf("$format")).format(new java.sql.Date($x / 10000)))""")
58+
})
11759
}
11860

11961
}
12062

121-
case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression {
63+
case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression
64+
with ExpectsInputTypes {
12265

12366
override def dataType: DataType = StringType
12467

125-
override def checkInputDataTypes(): TypeCheckResult =
126-
(left.dataType, right.dataType) match {
127-
case (null, _) => TypeCheckResult.TypeCheckSuccess
128-
case (_, null) => TypeCheckResult.TypeCheckSuccess
129-
case (_: DateType, _: StringType) => TypeCheckResult.TypeCheckSuccess
130-
case (_: TimestampType, _: StringType) => TypeCheckResult.TypeCheckSuccess
131-
case (_: StringType, _: StringType) => TypeCheckResult.TypeCheckSuccess
132-
case _ =>
133-
TypeCheckResult.TypeCheckFailure(s"DateFormat accepts date types as first argument, " +
134-
s"and string types as second, not ${left.dataType} and ${right.dataType}")
135-
}
136-
13768
override def toString: String = s"DateFormat($left, $right)"
69+
70+
override def expectedChildTypes: Seq[DataType] = Seq(TimestampType, StringType)
13871

13972
override def eval(input: InternalRow): Any = {
14073
val valueLeft = left.eval(input)
@@ -146,53 +79,23 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
14679
null
14780
} else {
14881
val sdf = new SimpleDateFormat(valueRight.toString)
149-
left.dataType match {
150-
case TimestampType =>
151-
UTF8String.fromString(sdf.format(new Date(valueLeft.asInstanceOf[Long] / 10000)))
152-
case DateType =>
153-
UTF8String.fromString(sdf.format(DateTimeUtils.toJavaDate(valueLeft.asInstanceOf[Int])))
154-
case StringType =>
155-
UTF8String.fromString(
156-
sdf.format(DateTimeUtils.stringToTime(valueLeft.toString)))
157-
}
82+
UTF8String.fromString(sdf.format(new Date(valueLeft.asInstanceOf[Long] / 10000)))
15883
}
15984
}
16085
}
16186

16287
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
16388
val sdf = classOf[SimpleDateFormat].getName
164-
val dtUtils = "org.apache.spark.sql.catalyst.util.DateTimeUtils"
165-
166-
val eval1 = left.gen(ctx)
167-
val eval2 = right.gen(ctx)
168-
169-
val parseInput = left.dataType match {
170-
case StringType => s"new java.sql.Date($dtUtils.stringToTime(${eval1.primitive}.toString()).getTime())"
171-
case TimestampType => s"new java.sql.Date(${eval1.primitive} / 10000)"
172-
case DateType => s"$dtUtils.toJavaDate(${eval1.primitive})"
173-
}
174-
175-
s"""
176-
${eval1.code}
177-
${eval2.code}
178-
boolean ${ev.isNull} = ${eval1.isNull} || ${eval2.isNull};
179-
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
180-
if (!${ev.isNull}) {
181-
$sdf sdf = new $sdf(${eval2.primitive}.toString());
182-
${ev.primitive} = ${ctx.stringType}.fromString(sdf.format($parseInput));
183-
} else {
184-
${ev.isNull} = true;
185-
}
186-
"""
89+
defineCodeGen(ctx, ev, (x, y) => {
90+
s"""${ctx.stringType}.fromString((new $sdf($y.toString())).format(new java.sql.Date($x / 10000)))"""
91+
})
18792
}
18893
}
18994

190-
case class Year(date: Expression) extends DateFormatExpression {
95+
case class Year(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
19196

19297
override protected val format: String = "y"
19398

194-
override protected val caller: String = "Year"
195-
19699
override def dataType: DataType = IntegerType
197100

198101
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -205,14 +108,14 @@ case class Year(date: Expression) extends DateFormatExpression {
205108
case s: UTF8String => s.toString.toInt
206109
}
207110
}
111+
112+
override def toString: String = s"Year($child)"
208113
}
209114

210-
case class Quarter(date: Expression) extends DateFormatExpression {
115+
case class Quarter(child: Expression) extends DateFormatExpression {
211116

212117
override protected val format: String = "M"
213118

214-
override protected val caller: String = "Quarter"
215-
216119
override def dataType: DataType = IntegerType
217120

218121
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -225,18 +128,16 @@ case class Quarter(date: Expression) extends DateFormatExpression {
225128
case s: UTF8String => (s.toString.toInt - 1) / 3 + 1
226129
}
227130
}
131+
132+
override def toString: String = s"Quarter($child)"
228133
}
229134

230-
case class Month(date: Expression) extends DateFormatExpression {
135+
case class Month(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
231136

232137
override protected val format: String = "M"
233138

234-
override protected val caller: String = "Month"
235-
236139
override def dataType: DataType = IntegerType
237140

238-
override def nullable: Boolean = true
239-
240141
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
241142
defineCodeGen(ctx, ev, c => s"Integer.parseInt($c.toString())")
242143
}
@@ -247,14 +148,14 @@ case class Month(date: Expression) extends DateFormatExpression {
247148
case s: UTF8String => s.toString.toInt
248149
}
249150
}
151+
152+
override def toString: String = s"Month($child)"
250153
}
251154

252-
case class Day(date: Expression) extends DateFormatExpression {
155+
case class Day(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
253156

254157
override protected val format: String = "d"
255158

256-
override protected val caller: String = "Day"
257-
258159
override def dataType: DataType = IntegerType
259160

260161
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -267,14 +168,14 @@ case class Day(date: Expression) extends DateFormatExpression {
267168
case s: UTF8String => s.toString.toInt
268169
}
269170
}
171+
172+
override def toString: String = s"Day($child)"
270173
}
271174

272-
case class Hour(date: Expression) extends DateFormatExpression {
175+
case class Hour(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
273176

274177
override protected val format: String = "H"
275178

276-
override protected val caller: String = "Hour"
277-
278179
override def dataType: DataType = IntegerType
279180

280181
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -287,14 +188,14 @@ case class Hour(date: Expression) extends DateFormatExpression {
287188
case s: UTF8String => s.toString.toInt
288189
}
289190
}
191+
192+
override def toString: String = s"Hour($child)"
290193
}
291194

292-
case class Minute(date: Expression) extends DateFormatExpression {
195+
case class Minute(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
293196

294197
override protected val format: String = "m"
295198

296-
override protected val caller: String = "Minute"
297-
298199
override def dataType: DataType = IntegerType
299200

300201
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -307,14 +208,14 @@ case class Minute(date: Expression) extends DateFormatExpression {
307208
case s: UTF8String => s.toString.toInt
308209
}
309210
}
211+
212+
override def toString: String = s"Minute($child)"
310213
}
311214

312-
case class Second(date: Expression) extends DateFormatExpression {
215+
case class Second(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
313216

314217
override protected val format: String = "s"
315218

316-
override protected val caller: String = "Second"
317-
318219
override def dataType: DataType = IntegerType
319220

320221
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -327,14 +228,14 @@ case class Second(date: Expression) extends DateFormatExpression {
327228
case s: UTF8String => s.toString.toInt
328229
}
329230
}
231+
232+
override def toString: String = s"Second($child)"
330233
}
331234

332-
case class WeekOfYear(date: Expression) extends DateFormatExpression {
235+
case class WeekOfYear(child: Expression) extends DateFormatExpression with ExpectsInputTypes {
333236

334237
override protected val format: String = "w"
335238

336-
override protected val caller: String = "WeekOfYear"
337-
338239
override def dataType: DataType = IntegerType
339240

340241
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -347,4 +248,6 @@ case class WeekOfYear(date: Expression) extends DateFormatExpression {
347248
case s: UTF8String => s.toString.toInt
348249
}
349250
}
251+
252+
override def toString: String = s"WeekOfYear($child)"
350253
}

0 commit comments

Comments
 (0)