@@ -194,10 +194,28 @@ class HiveTypeCoercionSuite extends PlanTest {
194194 Project (Seq (Alias (transformed, " a" )()), testRelation))
195195 }
196196
197- test(" null literals handling for binary operators" ) {
198- ruleTest(HiveTypeCoercion .WidenTypes ,
199- Add (Literal .create(null , NullType ), Literal .create(null , NullType )),
200- Add (
197+ test(" RemoveNullTypes for expressions that define ExpectsInputTypes" ) {
198+ import HiveTypeCoercionSuite ._
199+
200+ ruleTest(HiveTypeCoercion .RemoveNullTypes ,
201+ AnyTypeUnaryExpression (Literal .create(null , NullType )),
202+ AnyTypeUnaryExpression (Literal .create(null , NullType )))
203+
204+ ruleTest(HiveTypeCoercion .RemoveNullTypes ,
205+ NumericTypeUnaryExpression (Literal .create(null , NullType )),
206+ NumericTypeUnaryExpression (Cast (Literal .create(null , NullType ), DoubleType )))
207+ }
208+
209+ test(" RemoveNullTypes for binary operators" ) {
210+ import HiveTypeCoercionSuite ._
211+
212+ ruleTest(HiveTypeCoercion .RemoveNullTypes ,
213+ AnyTypeBinaryOperator (Literal .create(null , NullType ), Literal .create(null , NullType )),
214+ AnyTypeBinaryOperator (Literal .create(null , NullType ), Literal .create(null , NullType )))
215+
216+ ruleTest(HiveTypeCoercion .RemoveNullTypes ,
217+ NumericTypeBinaryOperator (Literal .create(null , NullType ), Literal .create(null , NullType )),
218+ NumericTypeBinaryOperator (
201219 Cast (Literal .create(null , NullType ), DoubleType ),
202220 Cast (Literal .create(null , NullType ), DoubleType )))
203221 }
@@ -310,3 +328,33 @@ class HiveTypeCoercionSuite extends PlanTest {
310328 )
311329 }
312330}
331+
332+
333+ object HiveTypeCoercionSuite {
334+
335+ case class AnyTypeUnaryExpression (child : Expression )
336+ extends UnaryExpression with ExpectsInputTypes {
337+ override def inputTypes : Seq [AbstractDataType ] = Seq (AnyDataType )
338+ override def dataType : DataType = NullType
339+ }
340+
341+ case class NumericTypeUnaryExpression (child : Expression )
342+ extends UnaryExpression with ExpectsInputTypes {
343+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
344+ override def dataType : DataType = NullType
345+ }
346+
347+ case class AnyTypeBinaryOperator (left : Expression , right : Expression )
348+ extends BinaryOperator with ExpectsInputTypes {
349+ override def dataType : DataType = NullType
350+ override def inputType : AbstractDataType = AnyDataType
351+ override def symbol : String = " anytype"
352+ }
353+
354+ case class NumericTypeBinaryOperator (left : Expression , right : Expression )
355+ extends BinaryOperator with ExpectsInputTypes {
356+ override def dataType : DataType = NullType
357+ override def inputType : AbstractDataType = NumericType
358+ override def symbol : String = " numerictype"
359+ }
360+ }
0 commit comments