Skip to content

Commit c2b88ae

Browse files
committed
[SPARK-2209][SQL] Cast shouldn't do null check twice.
1 parent 5464e79 commit c2b88ae

File tree

1 file changed

+162
-116
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+162
-116
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 162 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -24,72 +24,89 @@ import org.apache.spark.sql.catalyst.types._
2424
/** Cast the child expression to the target data type. */
2525
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
2626
override def foldable = child.foldable
27-
def nullable = (child.dataType, dataType) match {
27+
28+
override def nullable = (child.dataType, dataType) match {
2829
case (StringType, _: NumericType) => true
2930
case (StringType, TimestampType) => true
3031
case _ => child.nullable
3132
}
33+
3234
override def toString = s"CAST($child, $dataType)"
3335

3436
type EvaluatedType = Any
3537

36-
def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
37-
null
38-
} else {
39-
func(a.asInstanceOf[T])
40-
}
38+
// [[func]] assumes the input is no longer null because eval already does the null check.
39+
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
4140

4241
// UDFToString
43-
def castToString: Any => Any = child.dataType match {
44-
case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
45-
case _ => nullOrCast[Any](_, _.toString)
42+
private[this] def castToString: Any => Any = child.dataType match {
43+
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
44+
case _ => buildCast[Any](_, _.toString)
4645
}
4746

4847
// BinaryConverter
49-
def castToBinary: Any => Any = child.dataType match {
50-
case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
48+
private[this] def castToBinary: Any => Any = child.dataType match {
49+
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
5150
}
5251

5352
// UDFToBoolean
54-
def castToBoolean: Any => Any = child.dataType match {
55-
case StringType => nullOrCast[String](_, _.length() != 0)
56-
case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
57-
case LongType => nullOrCast[Long](_, _ != 0)
58-
case IntegerType => nullOrCast[Int](_, _ != 0)
59-
case ShortType => nullOrCast[Short](_, _ != 0)
60-
case ByteType => nullOrCast[Byte](_, _ != 0)
61-
case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
62-
case DoubleType => nullOrCast[Double](_, _ != 0)
63-
case FloatType => nullOrCast[Float](_, _ != 0)
53+
private[this] def castToBoolean: Any => Any = child.dataType match {
54+
case StringType =>
55+
buildCast[String](_, _.length() != 0)
56+
case TimestampType =>
57+
buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
58+
case LongType =>
59+
buildCast[Long](_, _ != 0)
60+
case IntegerType =>
61+
buildCast[Int](_, _ != 0)
62+
case ShortType =>
63+
buildCast[Short](_, _ != 0)
64+
case ByteType =>
65+
buildCast[Byte](_, _ != 0)
66+
case DecimalType =>
67+
buildCast[BigDecimal](_, _ != 0)
68+
case DoubleType =>
69+
buildCast[Double](_, _ != 0)
70+
case FloatType =>
71+
buildCast[Float](_, _ != 0)
6472
}
6573

6674
// TimestampConverter
67-
def castToTimestamp: Any => Any = child.dataType match {
68-
case StringType => nullOrCast[String](_, s => {
69-
// Throw away extra if more than 9 decimal places
70-
val periodIdx = s.indexOf(".");
71-
var n = s
72-
if (periodIdx != -1) {
73-
if (n.length() - periodIdx > 9) {
74-
n = n.substring(0, periodIdx + 10)
75+
private[this] def castToTimestamp: Any => Any = child.dataType match {
76+
case StringType =>
77+
buildCast[String](_, s => {
78+
// Throw away extra if more than 9 decimal places
79+
val periodIdx = s.indexOf(".")
80+
var n = s
81+
if (periodIdx != -1) {
82+
if (n.length() - periodIdx > 9) {
83+
n = n.substring(0, periodIdx + 10)
84+
}
7585
}
76-
}
77-
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
78-
})
79-
case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
80-
case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
81-
case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
82-
case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
83-
case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
86+
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
87+
})
88+
case BooleanType =>
89+
buildCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
90+
case LongType =>
91+
buildCast[Long](_, l => new Timestamp(l * 1000))
92+
case IntegerType =>
93+
buildCast[Int](_, i => new Timestamp(i * 1000))
94+
case ShortType =>
95+
buildCast[Short](_, s => new Timestamp(s * 1000))
96+
case ByteType =>
97+
buildCast[Byte](_, b => new Timestamp(b * 1000))
8498
// TimestampWritable.decimalToTimestamp
85-
case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
99+
case DecimalType =>
100+
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
86101
// TimestampWritable.doubleToTimestamp
87-
case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
102+
case DoubleType =>
103+
buildCast[Double](_, d => decimalToTimestamp(d))
88104
// TimestampWritable.floatToTimestamp
89-
case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
105+
case FloatType =>
106+
buildCast[Float](_, f => decimalToTimestamp(f))
90107
}
91108

92-
private def decimalToTimestamp(d: BigDecimal) = {
109+
private[this] def decimalToTimestamp(d: BigDecimal) = {
93110
val seconds = d.longValue()
94111
val bd = (d - seconds) * 1000000000
95112
val nanos = bd.intValue()
@@ -104,85 +121,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
104121
}
105122

106123
// Timestamp to long, converting milliseconds to seconds
107-
private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
124+
private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000
108125

109-
private def timestampToDouble(ts: Timestamp) = {
126+
private[this] def timestampToDouble(ts: Timestamp) = {
110127
// First part is the seconds since the beginning of time, followed by nanosecs.
111128
ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
112129
}
113130

114-
def castToLong: Any => Any = child.dataType match {
115-
case StringType => nullOrCast[String](_, s => try s.toLong catch {
116-
case _: NumberFormatException => null
117-
})
118-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
119-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
120-
case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
121-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
122-
}
123-
124-
def castToInt: Any => Any = child.dataType match {
125-
case StringType => nullOrCast[String](_, s => try s.toInt catch {
126-
case _: NumberFormatException => null
127-
})
128-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
129-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
130-
case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
131-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
132-
}
133-
134-
def castToShort: Any => Any = child.dataType match {
135-
case StringType => nullOrCast[String](_, s => try s.toShort catch {
136-
case _: NumberFormatException => null
137-
})
138-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
139-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
140-
case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
141-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
142-
}
143-
144-
def castToByte: Any => Any = child.dataType match {
145-
case StringType => nullOrCast[String](_, s => try s.toByte catch {
146-
case _: NumberFormatException => null
147-
})
148-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
149-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
150-
case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
151-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
152-
}
153-
154-
def castToDecimal: Any => Any = child.dataType match {
155-
case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
156-
case _: NumberFormatException => null
157-
})
158-
case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
131+
private[this] def castToLong: Any => Any = child.dataType match {
132+
case StringType =>
133+
buildCast[String](_, s => try s.toLong catch {
134+
case _: NumberFormatException => null
135+
})
136+
case BooleanType =>
137+
buildCast[Boolean](_, b => if (b) 1L else 0L)
138+
case TimestampType =>
139+
buildCast[Timestamp](_, t => timestampToLong(t))
140+
case DecimalType =>
141+
buildCast[BigDecimal](_, _.toLong)
142+
case x: NumericType =>
143+
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
144+
}
145+
146+
private[this] def castToInt: Any => Any = child.dataType match {
147+
case StringType =>
148+
buildCast[String](_, s => try s.toInt catch {
149+
case _: NumberFormatException => null
150+
})
151+
case BooleanType =>
152+
buildCast[Boolean](_, b => if (b) 1 else 0)
153+
case TimestampType =>
154+
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
155+
case DecimalType =>
156+
buildCast[BigDecimal](_, _.toInt)
157+
case x: NumericType =>
158+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
159+
}
160+
161+
private[this] def castToShort: Any => Any = child.dataType match {
162+
case StringType =>
163+
buildCast[String](_, s => try s.toShort catch {
164+
case _: NumberFormatException => null
165+
})
166+
case BooleanType =>
167+
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
168+
case TimestampType =>
169+
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
170+
case DecimalType =>
171+
buildCast[BigDecimal](_, _.toShort)
172+
case x: NumericType =>
173+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
174+
}
175+
176+
private[this] def castToByte: Any => Any = child.dataType match {
177+
case StringType =>
178+
buildCast[String](_, s => try s.toByte catch {
179+
case _: NumberFormatException => null
180+
})
181+
case BooleanType =>
182+
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
183+
case TimestampType =>
184+
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
185+
case DecimalType =>
186+
buildCast[BigDecimal](_, _.toByte)
187+
case x: NumericType =>
188+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
189+
}
190+
191+
private[this] def castToDecimal: Any => Any = child.dataType match {
192+
case StringType =>
193+
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
194+
case _: NumberFormatException => null
195+
})
196+
case BooleanType =>
197+
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
159198
case TimestampType =>
160199
// Note that we lose precision here.
161-
nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
162-
case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
163-
}
164-
165-
def castToDouble: Any => Any = child.dataType match {
166-
case StringType => nullOrCast[String](_, s => try s.toDouble catch {
167-
case _: NumberFormatException => null
168-
})
169-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
170-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
171-
case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
172-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
173-
}
174-
175-
def castToFloat: Any => Any = child.dataType match {
176-
case StringType => nullOrCast[String](_, s => try s.toFloat catch {
177-
case _: NumberFormatException => null
178-
})
179-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
180-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
181-
case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
182-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
200+
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
201+
case x: NumericType =>
202+
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
203+
}
204+
205+
private[this] def castToDouble: Any => Any = child.dataType match {
206+
case StringType =>
207+
buildCast[String](_, s => try s.toDouble catch {
208+
case _: NumberFormatException => null
209+
})
210+
case BooleanType =>
211+
buildCast[Boolean](_, b => if (b) 1d else 0d)
212+
case TimestampType =>
213+
buildCast[Timestamp](_, t => timestampToDouble(t))
214+
case DecimalType =>
215+
buildCast[BigDecimal](_, _.toDouble)
216+
case x: NumericType =>
217+
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
218+
}
219+
220+
private[this] def castToFloat: Any => Any = child.dataType match {
221+
case StringType =>
222+
buildCast[String](_, s => try s.toFloat catch {
223+
case _: NumberFormatException => null
224+
})
225+
case BooleanType =>
226+
buildCast[Boolean](_, b => if (b) 1f else 0f)
227+
case TimestampType =>
228+
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
229+
case DecimalType =>
230+
buildCast[BigDecimal](_, _.toFloat)
231+
case x: NumericType =>
232+
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
183233
}
184234

185-
private lazy val cast: Any => Any = dataType match {
235+
private[this] lazy val cast: Any => Any = dataType match {
186236
case StringType => castToString
187237
case BinaryType => castToBinary
188238
case DecimalType => castToDecimal
@@ -198,10 +248,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
198248

199249
override def eval(input: Row): Any = {
200250
val evaluated = child.eval(input)
201-
if (evaluated == null) {
202-
null
203-
} else {
204-
cast(evaluated)
205-
}
251+
if (evaluated == null) null else cast(evaluated)
206252
}
207253
}

0 commit comments

Comments
 (0)