Skip to content

Commit c55bbb4

Browse files
committed
[SPARK-2209][SQL] Cast shouldn't do null check twice.
Also took the chance to clean up cast a little bit. Too many arrows on each line before! Author: Reynold Xin <[email protected]> Closes apache#1143 from rxin/cast and squashes the following commits: dd006cb [Reynold Xin] Code review feedback. c2b88ae [Reynold Xin] [SPARK-2209][SQL] Cast shouldn't do null check twice.
1 parent 6175640 commit c55bbb4

File tree

1 file changed

+159
-115
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions

1 file changed

+159
-115
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 159 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._
2424
/** Cast the child expression to the target data type. */
2525
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
2626
override def foldable = child.foldable
27-
def nullable = (child.dataType, dataType) match {
27+
28+
override def nullable = (child.dataType, dataType) match {
2829
case (StringType, _: NumericType) => true
2930
case (StringType, TimestampType) => true
3031
case _ => child.nullable
3132
}
33+
3234
override def toString = s"CAST($child, $dataType)"
3335

3436
type EvaluatedType = Any
3537

36-
def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
37-
null
38-
} else {
39-
func(a.asInstanceOf[T])
40-
}
38+
// [[func]] assumes the input is no longer null because eval already does the null check.
39+
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
4140

4241
// UDFToString
43-
def castToString: Any => Any = child.dataType match {
44-
case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
45-
case _ => nullOrCast[Any](_, _.toString)
42+
private[this] def castToString: Any => Any = child.dataType match {
43+
case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8"))
44+
case _ => buildCast[Any](_, _.toString)
4645
}
4746

4847
// BinaryConverter
49-
def castToBinary: Any => Any = child.dataType match {
50-
case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
48+
private[this] def castToBinary: Any => Any = child.dataType match {
49+
case StringType => buildCast[String](_, _.getBytes("UTF-8"))
5150
}
5251

5352
// UDFToBoolean
54-
def castToBoolean: Any => Any = child.dataType match {
55-
case StringType => nullOrCast[String](_, _.length() != 0)
56-
case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
57-
case LongType => nullOrCast[Long](_, _ != 0)
58-
case IntegerType => nullOrCast[Int](_, _ != 0)
59-
case ShortType => nullOrCast[Short](_, _ != 0)
60-
case ByteType => nullOrCast[Byte](_, _ != 0)
61-
case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
62-
case DoubleType => nullOrCast[Double](_, _ != 0)
63-
case FloatType => nullOrCast[Float](_, _ != 0)
53+
private[this] def castToBoolean: Any => Any = child.dataType match {
54+
case StringType =>
55+
buildCast[String](_, _.length() != 0)
56+
case TimestampType =>
57+
buildCast[Timestamp](_, b => b.getTime() != 0 || b.getNanos() != 0)
58+
case LongType =>
59+
buildCast[Long](_, _ != 0)
60+
case IntegerType =>
61+
buildCast[Int](_, _ != 0)
62+
case ShortType =>
63+
buildCast[Short](_, _ != 0)
64+
case ByteType =>
65+
buildCast[Byte](_, _ != 0)
66+
case DecimalType =>
67+
buildCast[BigDecimal](_, _ != 0)
68+
case DoubleType =>
69+
buildCast[Double](_, _ != 0)
70+
case FloatType =>
71+
buildCast[Float](_, _ != 0)
6472
}
6573

6674
// TimestampConverter
67-
def castToTimestamp: Any => Any = child.dataType match {
68-
case StringType => nullOrCast[String](_, s => {
69-
// Throw away extra if more than 9 decimal places
70-
val periodIdx = s.indexOf(".");
71-
var n = s
72-
if (periodIdx != -1) {
73-
if (n.length() - periodIdx > 9) {
75+
private[this] def castToTimestamp: Any => Any = child.dataType match {
76+
case StringType =>
77+
buildCast[String](_, s => {
78+
// Throw away extra if more than 9 decimal places
79+
val periodIdx = s.indexOf(".")
80+
var n = s
81+
if (periodIdx != -1 && n.length() - periodIdx > 9) {
7482
n = n.substring(0, periodIdx + 10)
7583
}
76-
}
77-
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
78-
})
79-
case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
80-
case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
81-
case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
82-
case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
83-
case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
84+
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
85+
})
86+
case BooleanType =>
87+
buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000))
88+
case LongType =>
89+
buildCast[Long](_, l => new Timestamp(l * 1000))
90+
case IntegerType =>
91+
buildCast[Int](_, i => new Timestamp(i * 1000))
92+
case ShortType =>
93+
buildCast[Short](_, s => new Timestamp(s * 1000))
94+
case ByteType =>
95+
buildCast[Byte](_, b => new Timestamp(b * 1000))
8496
// TimestampWritable.decimalToTimestamp
85-
case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
97+
case DecimalType =>
98+
buildCast[BigDecimal](_, d => decimalToTimestamp(d))
8699
// TimestampWritable.doubleToTimestamp
87-
case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
100+
case DoubleType =>
101+
buildCast[Double](_, d => decimalToTimestamp(d))
88102
// TimestampWritable.floatToTimestamp
89-
case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
103+
case FloatType =>
104+
buildCast[Float](_, f => decimalToTimestamp(f))
90105
}
91106

92-
private def decimalToTimestamp(d: BigDecimal) = {
107+
private[this] def decimalToTimestamp(d: BigDecimal) = {
93108
val seconds = d.longValue()
94109
val bd = (d - seconds) * 1000000000
95110
val nanos = bd.intValue()
@@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
104119
}
105120

106121
// Timestamp to long, converting milliseconds to seconds
107-
private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
122+
private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000
108123

109-
private def timestampToDouble(ts: Timestamp) = {
124+
private[this] def timestampToDouble(ts: Timestamp) = {
110125
// First part is the seconds since the beginning of time, followed by nanosecs.
111126
ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
112127
}
113128

114-
def castToLong: Any => Any = child.dataType match {
115-
case StringType => nullOrCast[String](_, s => try s.toLong catch {
116-
case _: NumberFormatException => null
117-
})
118-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
119-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
120-
case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
121-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
122-
}
123-
124-
def castToInt: Any => Any = child.dataType match {
125-
case StringType => nullOrCast[String](_, s => try s.toInt catch {
126-
case _: NumberFormatException => null
127-
})
128-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
129-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
130-
case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
131-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
132-
}
133-
134-
def castToShort: Any => Any = child.dataType match {
135-
case StringType => nullOrCast[String](_, s => try s.toShort catch {
136-
case _: NumberFormatException => null
137-
})
138-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
139-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
140-
case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
141-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
142-
}
143-
144-
def castToByte: Any => Any = child.dataType match {
145-
case StringType => nullOrCast[String](_, s => try s.toByte catch {
146-
case _: NumberFormatException => null
147-
})
148-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
149-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
150-
case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
151-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
152-
}
153-
154-
def castToDecimal: Any => Any = child.dataType match {
155-
case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
156-
case _: NumberFormatException => null
157-
})
158-
case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
129+
private[this] def castToLong: Any => Any = child.dataType match {
130+
case StringType =>
131+
buildCast[String](_, s => try s.toLong catch {
132+
case _: NumberFormatException => null
133+
})
134+
case BooleanType =>
135+
buildCast[Boolean](_, b => if (b) 1L else 0L)
136+
case TimestampType =>
137+
buildCast[Timestamp](_, t => timestampToLong(t))
138+
case DecimalType =>
139+
buildCast[BigDecimal](_, _.toLong)
140+
case x: NumericType =>
141+
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
142+
}
143+
144+
private[this] def castToInt: Any => Any = child.dataType match {
145+
case StringType =>
146+
buildCast[String](_, s => try s.toInt catch {
147+
case _: NumberFormatException => null
148+
})
149+
case BooleanType =>
150+
buildCast[Boolean](_, b => if (b) 1 else 0)
151+
case TimestampType =>
152+
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
153+
case DecimalType =>
154+
buildCast[BigDecimal](_, _.toInt)
155+
case x: NumericType =>
156+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
157+
}
158+
159+
private[this] def castToShort: Any => Any = child.dataType match {
160+
case StringType =>
161+
buildCast[String](_, s => try s.toShort catch {
162+
case _: NumberFormatException => null
163+
})
164+
case BooleanType =>
165+
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
166+
case TimestampType =>
167+
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
168+
case DecimalType =>
169+
buildCast[BigDecimal](_, _.toShort)
170+
case x: NumericType =>
171+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
172+
}
173+
174+
private[this] def castToByte: Any => Any = child.dataType match {
175+
case StringType =>
176+
buildCast[String](_, s => try s.toByte catch {
177+
case _: NumberFormatException => null
178+
})
179+
case BooleanType =>
180+
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
181+
case TimestampType =>
182+
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
183+
case DecimalType =>
184+
buildCast[BigDecimal](_, _.toByte)
185+
case x: NumericType =>
186+
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
187+
}
188+
189+
private[this] def castToDecimal: Any => Any = child.dataType match {
190+
case StringType =>
191+
buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
192+
case _: NumberFormatException => null
193+
})
194+
case BooleanType =>
195+
buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
159196
case TimestampType =>
160197
// Note that we lose precision here.
161-
nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
162-
case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
163-
}
164-
165-
def castToDouble: Any => Any = child.dataType match {
166-
case StringType => nullOrCast[String](_, s => try s.toDouble catch {
167-
case _: NumberFormatException => null
168-
})
169-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
170-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
171-
case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
172-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
173-
}
174-
175-
def castToFloat: Any => Any = child.dataType match {
176-
case StringType => nullOrCast[String](_, s => try s.toFloat catch {
177-
case _: NumberFormatException => null
178-
})
179-
case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
180-
case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
181-
case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
182-
case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
198+
buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
199+
case x: NumericType =>
200+
b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
201+
}
202+
203+
private[this] def castToDouble: Any => Any = child.dataType match {
204+
case StringType =>
205+
buildCast[String](_, s => try s.toDouble catch {
206+
case _: NumberFormatException => null
207+
})
208+
case BooleanType =>
209+
buildCast[Boolean](_, b => if (b) 1d else 0d)
210+
case TimestampType =>
211+
buildCast[Timestamp](_, t => timestampToDouble(t))
212+
case DecimalType =>
213+
buildCast[BigDecimal](_, _.toDouble)
214+
case x: NumericType =>
215+
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
216+
}
217+
218+
private[this] def castToFloat: Any => Any = child.dataType match {
219+
case StringType =>
220+
buildCast[String](_, s => try s.toFloat catch {
221+
case _: NumberFormatException => null
222+
})
223+
case BooleanType =>
224+
buildCast[Boolean](_, b => if (b) 1f else 0f)
225+
case TimestampType =>
226+
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
227+
case DecimalType =>
228+
buildCast[BigDecimal](_, _.toFloat)
229+
case x: NumericType =>
230+
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
183231
}
184232

185-
private lazy val cast: Any => Any = dataType match {
233+
private[this] lazy val cast: Any => Any = dataType match {
186234
case StringType => castToString
187235
case BinaryType => castToBinary
188236
case DecimalType => castToDecimal
@@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
198246

199247
override def eval(input: Row): Any = {
200248
val evaluated = child.eval(input)
201-
if (evaluated == null) {
202-
null
203-
} else {
204-
cast(evaluated)
205-
}
249+
if (evaluated == null) null else cast(evaluated)
206250
}
207251
}

0 commit comments

Comments
 (0)