@@ -24,72 +24,89 @@ import org.apache.spark.sql.catalyst.types._
2424/** Cast the child expression to the target data type. */
2525case class Cast (child : Expression , dataType : DataType ) extends UnaryExpression {
2626 override def foldable = child.foldable
27- def nullable = (child.dataType, dataType) match {
27+
28+ override def nullable = (child.dataType, dataType) match {
2829 case (StringType , _ : NumericType ) => true
2930 case (StringType , TimestampType ) => true
3031 case _ => child.nullable
3132 }
33+
3234 override def toString = s " CAST( $child, $dataType) "
3335
3436 type EvaluatedType = Any
3537
36- def nullOrCast [T ](a : Any , func : T => Any ): Any = if (a == null ) {
37- null
38- } else {
39- func(a.asInstanceOf [T ])
40- }
38+ // [[func]] assumes the input is no longer null because eval already does the null check.
39+ @ inline private [this ] def buildCast [T ](a : Any , func : T => Any ): Any = func(a.asInstanceOf [T ])
4140
4241 // UDFToString
43- def castToString : Any => Any = child.dataType match {
44- case BinaryType => nullOrCast [Array [Byte ]](_, new String (_, " UTF-8" ))
45- case _ => nullOrCast [Any ](_, _.toString)
42+ private [ this ] def castToString : Any => Any = child.dataType match {
43+ case BinaryType => buildCast [Array [Byte ]](_, new String (_, " UTF-8" ))
44+ case _ => buildCast [Any ](_, _.toString)
4645 }
4746
4847 // BinaryConverter
49- def castToBinary : Any => Any = child.dataType match {
50- case StringType => nullOrCast [String ](_, _.getBytes(" UTF-8" ))
48+ private [ this ] def castToBinary : Any => Any = child.dataType match {
49+ case StringType => buildCast [String ](_, _.getBytes(" UTF-8" ))
5150 }
5251
5352 // UDFToBoolean
54- def castToBoolean : Any => Any = child.dataType match {
55- case StringType => nullOrCast[String ](_, _.length() != 0 )
56- case TimestampType => nullOrCast[Timestamp ](_, b => {(b.getTime() != 0 || b.getNanos() != 0 )})
57- case LongType => nullOrCast[Long ](_, _ != 0 )
58- case IntegerType => nullOrCast[Int ](_, _ != 0 )
59- case ShortType => nullOrCast[Short ](_, _ != 0 )
60- case ByteType => nullOrCast[Byte ](_, _ != 0 )
61- case DecimalType => nullOrCast[BigDecimal ](_, _ != 0 )
62- case DoubleType => nullOrCast[Double ](_, _ != 0 )
63- case FloatType => nullOrCast[Float ](_, _ != 0 )
53+ private [this ] def castToBoolean : Any => Any = child.dataType match {
54+ case StringType =>
55+ buildCast[String ](_, _.length() != 0 )
56+ case TimestampType =>
57+ buildCast[Timestamp ](_, b => b.getTime() != 0 || b.getNanos() != 0 )
58+ case LongType =>
59+ buildCast[Long ](_, _ != 0 )
60+ case IntegerType =>
61+ buildCast[Int ](_, _ != 0 )
62+ case ShortType =>
63+ buildCast[Short ](_, _ != 0 )
64+ case ByteType =>
65+ buildCast[Byte ](_, _ != 0 )
66+ case DecimalType =>
67+ buildCast[BigDecimal ](_, _ != 0 )
68+ case DoubleType =>
69+ buildCast[Double ](_, _ != 0 )
70+ case FloatType =>
71+ buildCast[Float ](_, _ != 0 )
6472 }
6573
6674 // TimestampConverter
67- def castToTimestamp : Any => Any = child.dataType match {
68- case StringType => nullOrCast[String ](_, s => {
69- // Throw away extra if more than 9 decimal places
70- val periodIdx = s.indexOf(" ." );
71- var n = s
72- if (periodIdx != - 1 ) {
73- if (n.length() - periodIdx > 9 ) {
74- n = n.substring(0 , periodIdx + 10 )
75+ private [this ] def castToTimestamp : Any => Any = child.dataType match {
76+ case StringType =>
77+ buildCast[String ](_, s => {
78+ // Throw away extra if more than 9 decimal places
79+ val periodIdx = s.indexOf(" ." )
80+ var n = s
81+ if (periodIdx != - 1 ) {
82+ if (n.length() - periodIdx > 9 ) {
83+ n = n.substring(0 , periodIdx + 10 )
84+ }
7585 }
76- }
77- try Timestamp .valueOf(n) catch { case _ : java.lang.IllegalArgumentException => null }
78- })
79- case BooleanType => nullOrCast[Boolean ](_, b => new Timestamp ((if (b) 1 else 0 ) * 1000 ))
80- case LongType => nullOrCast[Long ](_, l => new Timestamp (l * 1000 ))
81- case IntegerType => nullOrCast[Int ](_, i => new Timestamp (i * 1000 ))
82- case ShortType => nullOrCast[Short ](_, s => new Timestamp (s * 1000 ))
83- case ByteType => nullOrCast[Byte ](_, b => new Timestamp (b * 1000 ))
86+ try Timestamp .valueOf(n) catch { case _ : java.lang.IllegalArgumentException => null }
87+ })
88+ case BooleanType =>
89+ buildCast[Boolean ](_, b => new Timestamp ((if (b) 1 else 0 ) * 1000 ))
90+ case LongType =>
91+ buildCast[Long ](_, l => new Timestamp (l * 1000 ))
92+ case IntegerType =>
93+ buildCast[Int ](_, i => new Timestamp (i * 1000 ))
94+ case ShortType =>
95+ buildCast[Short ](_, s => new Timestamp (s * 1000 ))
96+ case ByteType =>
97+ buildCast[Byte ](_, b => new Timestamp (b * 1000 ))
8498 // TimestampWritable.decimalToTimestamp
85- case DecimalType => nullOrCast[BigDecimal ](_, d => decimalToTimestamp(d))
99+ case DecimalType =>
100+ buildCast[BigDecimal ](_, d => decimalToTimestamp(d))
86101 // TimestampWritable.doubleToTimestamp
87- case DoubleType => nullOrCast[Double ](_, d => decimalToTimestamp(d))
102+ case DoubleType =>
103+ buildCast[Double ](_, d => decimalToTimestamp(d))
88104 // TimestampWritable.floatToTimestamp
89- case FloatType => nullOrCast[Float ](_, f => decimalToTimestamp(f))
105+ case FloatType =>
106+ buildCast[Float ](_, f => decimalToTimestamp(f))
90107 }
91108
92- private def decimalToTimestamp (d : BigDecimal ) = {
109+ private [ this ] def decimalToTimestamp (d : BigDecimal ) = {
93110 val seconds = d.longValue()
94111 val bd = (d - seconds) * 1000000000
95112 val nanos = bd.intValue()
@@ -104,85 +121,118 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
104121 }
105122
106123 // Timestamp to long, converting milliseconds to seconds
107- private def timestampToLong (ts : Timestamp ) = ts.getTime / 1000
124+ private [ this ] def timestampToLong (ts : Timestamp ) = ts.getTime / 1000
108125
109- private def timestampToDouble (ts : Timestamp ) = {
126+ private [ this ] def timestampToDouble (ts : Timestamp ) = {
110127 // First part is the seconds since the beginning of time, followed by nanosecs.
111128 ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
112129 }
113130
114- def castToLong : Any => Any = child.dataType match {
115- case StringType => nullOrCast[String ](_, s => try s.toLong catch {
116- case _ : NumberFormatException => null
117- })
118- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1L else 0L )
119- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToLong(t))
120- case DecimalType => nullOrCast[BigDecimal ](_, _.toLong)
121- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toLong(b)
122- }
123-
124- def castToInt : Any => Any = child.dataType match {
125- case StringType => nullOrCast[String ](_, s => try s.toInt catch {
126- case _ : NumberFormatException => null
127- })
128- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1 else 0 )
129- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToLong(t).toInt)
130- case DecimalType => nullOrCast[BigDecimal ](_, _.toInt)
131- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b)
132- }
133-
134- def castToShort : Any => Any = child.dataType match {
135- case StringType => nullOrCast[String ](_, s => try s.toShort catch {
136- case _ : NumberFormatException => null
137- })
138- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1 .toShort else 0 .toShort)
139- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToLong(t).toShort)
140- case DecimalType => nullOrCast[BigDecimal ](_, _.toShort)
141- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toShort
142- }
143-
144- def castToByte : Any => Any = child.dataType match {
145- case StringType => nullOrCast[String ](_, s => try s.toByte catch {
146- case _ : NumberFormatException => null
147- })
148- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1 .toByte else 0 .toByte)
149- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToLong(t).toByte)
150- case DecimalType => nullOrCast[BigDecimal ](_, _.toByte)
151- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toByte
152- }
153-
154- def castToDecimal : Any => Any = child.dataType match {
155- case StringType => nullOrCast[String ](_, s => try BigDecimal (s.toDouble) catch {
156- case _ : NumberFormatException => null
157- })
158- case BooleanType => nullOrCast[Boolean ](_, b => if (b) BigDecimal (1 ) else BigDecimal (0 ))
131+ private [this ] def castToLong : Any => Any = child.dataType match {
132+ case StringType =>
133+ buildCast[String ](_, s => try s.toLong catch {
134+ case _ : NumberFormatException => null
135+ })
136+ case BooleanType =>
137+ buildCast[Boolean ](_, b => if (b) 1L else 0L )
138+ case TimestampType =>
139+ buildCast[Timestamp ](_, t => timestampToLong(t))
140+ case DecimalType =>
141+ buildCast[BigDecimal ](_, _.toLong)
142+ case x : NumericType =>
143+ b => x.numeric.asInstanceOf [Numeric [Any ]].toLong(b)
144+ }
145+
146+ private [this ] def castToInt : Any => Any = child.dataType match {
147+ case StringType =>
148+ buildCast[String ](_, s => try s.toInt catch {
149+ case _ : NumberFormatException => null
150+ })
151+ case BooleanType =>
152+ buildCast[Boolean ](_, b => if (b) 1 else 0 )
153+ case TimestampType =>
154+ buildCast[Timestamp ](_, t => timestampToLong(t).toInt)
155+ case DecimalType =>
156+ buildCast[BigDecimal ](_, _.toInt)
157+ case x : NumericType =>
158+ b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b)
159+ }
160+
161+ private [this ] def castToShort : Any => Any = child.dataType match {
162+ case StringType =>
163+ buildCast[String ](_, s => try s.toShort catch {
164+ case _ : NumberFormatException => null
165+ })
166+ case BooleanType =>
167+ buildCast[Boolean ](_, b => if (b) 1 .toShort else 0 .toShort)
168+ case TimestampType =>
169+ buildCast[Timestamp ](_, t => timestampToLong(t).toShort)
170+ case DecimalType =>
171+ buildCast[BigDecimal ](_, _.toShort)
172+ case x : NumericType =>
173+ b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toShort
174+ }
175+
176+ private [this ] def castToByte : Any => Any = child.dataType match {
177+ case StringType =>
178+ buildCast[String ](_, s => try s.toByte catch {
179+ case _ : NumberFormatException => null
180+ })
181+ case BooleanType =>
182+ buildCast[Boolean ](_, b => if (b) 1 .toByte else 0 .toByte)
183+ case TimestampType =>
184+ buildCast[Timestamp ](_, t => timestampToLong(t).toByte)
185+ case DecimalType =>
186+ buildCast[BigDecimal ](_, _.toByte)
187+ case x : NumericType =>
188+ b => x.numeric.asInstanceOf [Numeric [Any ]].toInt(b).toByte
189+ }
190+
191+ private [this ] def castToDecimal : Any => Any = child.dataType match {
192+ case StringType =>
193+ buildCast[String ](_, s => try BigDecimal (s.toDouble) catch {
194+ case _ : NumberFormatException => null
195+ })
196+ case BooleanType =>
197+ buildCast[Boolean ](_, b => if (b) BigDecimal (1 ) else BigDecimal (0 ))
159198 case TimestampType =>
160199 // Note that we lose precision here.
161- nullOrCast[Timestamp ](_, t => BigDecimal (timestampToDouble(t)))
162- case x : NumericType => b => BigDecimal (x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b))
163- }
164-
165- def castToDouble : Any => Any = child.dataType match {
166- case StringType => nullOrCast[String ](_, s => try s.toDouble catch {
167- case _ : NumberFormatException => null
168- })
169- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1d else 0d )
170- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToDouble(t))
171- case DecimalType => nullOrCast[BigDecimal ](_, _.toDouble)
172- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)
173- }
174-
175- def castToFloat : Any => Any = child.dataType match {
176- case StringType => nullOrCast[String ](_, s => try s.toFloat catch {
177- case _ : NumberFormatException => null
178- })
179- case BooleanType => nullOrCast[Boolean ](_, b => if (b) 1f else 0f )
180- case TimestampType => nullOrCast[Timestamp ](_, t => timestampToDouble(t).toFloat)
181- case DecimalType => nullOrCast[BigDecimal ](_, _.toFloat)
182- case x : NumericType => b => x.numeric.asInstanceOf [Numeric [Any ]].toFloat(b)
200+ buildCast[Timestamp ](_, t => BigDecimal (timestampToDouble(t)))
201+ case x : NumericType =>
202+ b => BigDecimal (x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b))
203+ }
204+
205+ private [this ] def castToDouble : Any => Any = child.dataType match {
206+ case StringType =>
207+ buildCast[String ](_, s => try s.toDouble catch {
208+ case _ : NumberFormatException => null
209+ })
210+ case BooleanType =>
211+ buildCast[Boolean ](_, b => if (b) 1d else 0d )
212+ case TimestampType =>
213+ buildCast[Timestamp ](_, t => timestampToDouble(t))
214+ case DecimalType =>
215+ buildCast[BigDecimal ](_, _.toDouble)
216+ case x : NumericType =>
217+ b => x.numeric.asInstanceOf [Numeric [Any ]].toDouble(b)
218+ }
219+
220+ private [this ] def castToFloat : Any => Any = child.dataType match {
221+ case StringType =>
222+ buildCast[String ](_, s => try s.toFloat catch {
223+ case _ : NumberFormatException => null
224+ })
225+ case BooleanType =>
226+ buildCast[Boolean ](_, b => if (b) 1f else 0f )
227+ case TimestampType =>
228+ buildCast[Timestamp ](_, t => timestampToDouble(t).toFloat)
229+ case DecimalType =>
230+ buildCast[BigDecimal ](_, _.toFloat)
231+ case x : NumericType =>
232+ b => x.numeric.asInstanceOf [Numeric [Any ]].toFloat(b)
183233 }
184234
185- private lazy val cast : Any => Any = dataType match {
235+ private [ this ] lazy val cast : Any => Any = dataType match {
186236 case StringType => castToString
187237 case BinaryType => castToBinary
188238 case DecimalType => castToDecimal
@@ -198,10 +248,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
198248
199249 override def eval (input : Row ): Any = {
200250 val evaluated = child.eval(input)
201- if (evaluated == null ) {
202- null
203- } else {
204- cast(evaluated)
205- }
251+ if (evaluated == null ) null else cast(evaluated)
206252 }
207253}
0 commit comments