@@ -24,72 +24,87 @@ import org.apache.spark.sql.catalyst.types._
2424/** Cast the child expression to the target data type. */
2525case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression {
/** A cast can be constant-folded exactly when its child expression can. */
override def foldable = child.foldable
/**
 * Whether this cast may produce null. Casting a string to a numeric type or to a
 * timestamp is always nullable; every other cast inherits the child's nullability.
 */
override def nullable = child.dataType match {
  case StringType =>
    dataType match {
      case _: NumericType | TimestampType => true
      case _ => child.nullable
    }
  case _ => child.nullable
}
33+
/** Renders this expression as `CAST(child, dataType)`. */
override def toString = s"CAST(${child}, ${dataType})"

// Results are untyped because the concrete class depends on the target data type.
type EvaluatedType = Any
3537
/**
 * Applies `func` to `a` after viewing it as a `T`.
 *
 * [[func]] assumes the input is no longer null because eval already does the null check.
 */
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = {
  val typed = a.asInstanceOf[T]
  func(typed)
}
4140
4241 // UDFToString
/**
 * Builds the converter to String for the child's data type: binary values are decoded
 * as UTF-8, every other value is rendered through its `toString`.
 */
private[this] def castToString: Any => Any = child.dataType match {
  case BinaryType =>
    value => buildCast[Array[Byte]](value, bytes => new String(bytes, "UTF-8"))
  case _ =>
    value => buildCast[Any](value, other => other.toString)
}
4746
4847 // BinaryConverter
/**
 * Builds the converter to binary. Only a String child is handled (encoded as UTF-8);
 * any other child data type falls through the match.
 */
private[this] def castToBinary: Any => Any = child.dataType match {
  case StringType =>
    value => buildCast[String](value, str => str.getBytes("UTF-8"))
}
5251
5352 // UDFToBoolean
/**
 * Builds the converter to Boolean for the child's data type: strings are true when
 * non-empty, timestamps when either the millisecond or the nanosecond part is non-zero,
 * and numeric values when they differ from zero.
 */
private[this] def castToBoolean: Any => Any = child.dataType match {
  case StringType => buildCast[String](_, str => str.length() != 0)
  case TimestampType =>
    buildCast[Timestamp](_, ts => !(ts.getTime() == 0 && ts.getNanos() == 0))
  case LongType => buildCast[Long](_, num => num != 0)
  case IntegerType => buildCast[Int](_, num => num != 0)
  case ShortType => buildCast[Short](_, num => num != 0)
  case ByteType => buildCast[Byte](_, num => num != 0)
  case DecimalType => buildCast[BigDecimal](_, num => num != 0)
  case DoubleType => buildCast[Double](_, num => num != 0)
  case FloatType => buildCast[Float](_, num => num != 0)
}
6573
6674 // TimestampConverter
/**
 * Builds the converter to Timestamp for the child's data type.
 *
 * Strings are parsed with `Timestamp.valueOf` and yield null when they are not in the
 * expected format. Boolean and integral values are multiplied by 1000 to obtain the
 * milliseconds passed to the `Timestamp` constructor. Decimal, double and float values
 * go through `decimalToTimestamp`.
 */
private[this] def castToTimestamp: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, s => {
      // Throw away extra if more than 9 decimal places
      val periodIdx = s.indexOf(".")
      val n =
        if (periodIdx != -1 && s.length() - periodIdx > 9) s.substring(0, periodIdx + 10) else s
      try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
    })
  case BooleanType =>
    buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000))
  case LongType =>
    buildCast[Long](_, l => new Timestamp(l * 1000))
  case IntegerType =>
    // Widen before multiplying: `i * 1000` in Int arithmetic wraps around for
    // |i| > 2147483 and would silently produce a wrong timestamp.
    buildCast[Int](_, i => new Timestamp(i.toLong * 1000))
  case ShortType =>
    // Short * Int is computed as Int and cannot overflow for any Short value.
    buildCast[Short](_, s => new Timestamp(s * 1000))
  case ByteType =>
    buildCast[Byte](_, b => new Timestamp(b * 1000))
  // TimestampWritable.decimalToTimestamp
  case DecimalType =>
    buildCast[BigDecimal](_, d => decimalToTimestamp(d))
  // TimestampWritable.doubleToTimestamp
  case DoubleType =>
    buildCast[Double](_, d => decimalToTimestamp(d))
  // TimestampWritable.floatToTimestamp
  case FloatType =>
    buildCast[Float](_, f => decimalToTimestamp(f))
}
91106
92- private def decimalToTimestamp (d : BigDecimal ) = {
107+ private [ this ] def decimalToTimestamp (d : BigDecimal ) = {
93108 val seconds = d.longValue()
94109 val bd = (d - seconds) * 1000000000
95110 val nanos = bd.intValue()
@@ -104,85 +119,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
104119 }
105120
106121 // Timestamp to long, converting milliseconds to seconds
private[this] def timestampToLong(ts: Timestamp) = {
  // getTime is in milliseconds; integer division drops the sub-second part.
  val millis = ts.getTime
  millis / 1000
}
108123
/**
 * Converts a timestamp to fractional seconds since the epoch, combining the whole
 * seconds with the nanosecond field.
 */
private[this] def timestampToDouble(ts: Timestamp) = {
  // getTime already contains the millisecond portion of getNanos. Removing it leaves an
  // exact multiple of 1000, so the division yields the true whole-second count even for
  // pre-epoch timestamps (where plain `getTime / 1000` truncates toward zero while
  // getNanos stays non-negative, e.g. -500ms would come out as +0.5).
  val wholeSeconds = (ts.getTime - ts.getNanos / 1000000) / 1000
  // First part is the seconds since the beginning of time, followed by nanosecs.
  wholeSeconds + ts.getNanos.toDouble / 1000000000
}
113128
/**
 * Builds the converter to Long for the child's data type. Strings that do not parse
 * yield null, booleans map to 1/0, timestamps to their whole-second count.
 */
private[this] def castToLong: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toLong catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1L else 0L)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToLong(ts))
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toLong)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toLong(value)
}

/**
 * Builds the converter to Int for the child's data type. Strings that do not parse
 * yield null, booleans map to 1/0, timestamps to their whole-second count.
 */
private[this] def castToInt: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toInt catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1 else 0)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToLong(ts).toInt)
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toInt)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toInt(value)
}

/**
 * Builds the converter to Short for the child's data type. Strings that do not parse
 * yield null; other numeric children are converted via Int and then narrowed.
 */
private[this] def castToShort: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toShort catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1.toShort else 0.toShort)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToLong(ts).toShort)
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toShort)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toInt(value).toShort
}

/**
 * Builds the converter to Byte for the child's data type. Strings that do not parse
 * yield null; other numeric children are converted via Int and then narrowed.
 */
private[this] def castToByte: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toByte catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1.toByte else 0.toByte)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToLong(ts).toByte)
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toByte)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toInt(value).toByte
}

/**
 * Builds the converter to BigDecimal for the child's data type. String, timestamp and
 * numeric children all pass through a Double on the way to BigDecimal.
 */
private[this] def castToDecimal: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try BigDecimal(text.toDouble) catch { case _: NumberFormatException => null })
  case BooleanType =>
    buildCast[Boolean](_, flag => if (flag) BigDecimal(1) else BigDecimal(0))
  case TimestampType =>
    // Note that we lose precision here.
    buildCast[Timestamp](_, ts => BigDecimal(timestampToDouble(ts)))
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => BigDecimal(numeric.toDouble(value))
}

/**
 * Builds the converter to Double for the child's data type. Strings that do not parse
 * yield null, booleans map to 1/0, timestamps to fractional seconds.
 */
private[this] def castToDouble: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toDouble catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1d else 0d)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToDouble(ts))
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toDouble)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toDouble(value)
}

/**
 * Builds the converter to Float for the child's data type. Strings that do not parse
 * yield null, booleans map to 1/0, timestamps to fractional seconds narrowed to Float.
 */
private[this] def castToFloat: Any => Any = child.dataType match {
  case StringType =>
    buildCast[String](_, text =>
      try text.toFloat catch { case _: NumberFormatException => null })
  case BooleanType => buildCast[Boolean](_, flag => if (flag) 1f else 0f)
  case TimestampType => buildCast[Timestamp](_, ts => timestampToDouble(ts).toFloat)
  case DecimalType => buildCast[BigDecimal](_, dec => dec.toFloat)
  case x: NumericType =>
    val numeric = x.numeric.asInstanceOf[Numeric[Any]]
    value => numeric.toFloat(value)
}
184232
185- private lazy val cast : Any => Any = dataType match {
233+ private [ this ] lazy val cast : Any => Any = dataType match {
186234 case StringType => castToString
187235 case BinaryType => castToBinary
188236 case DecimalType => castToDecimal
@@ -198,10 +246,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
198246
/**
 * Evaluates the child against `input` and converts the result to the target type.
 * A null child value short-circuits to null, so the converters never see null.
 */
override def eval(input: Row): Any = {
  child.eval(input) match {
    case null => null
    case value => cast(value)
  }
}
207251}
0 commit comments