Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val literal: Parser[Literal] =
( numericLiteral
| booleanLiteral
| stringLit ^^ {case s => Literal(s, StringType) }
| NULL ^^^ Literal(null, NullType)
| stringLit ^^ {case s => Literal.create(s, StringType) }
| NULL ^^^ Literal.create(null, NullType)
)

protected lazy val booleanLiteral: Parser[Literal] =
( TRUE ^^^ Literal(true, BooleanType)
| FALSE ^^^ Literal(false, BooleanType)
( TRUE ^^^ Literal.create(true, BooleanType)
| FALSE ^^^ Literal.create(false, BooleanType)
)

protected lazy val numericLiteral: Parser[Literal] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ class Analyzer(
case x: Expression if nonSelectedGroupExprSet.contains(x) =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal(null, expr.dataType)
Literal.create(null, expr.dataType)
case x if x == g.gid =>
// replace the groupingId with concrete value (the bit mask)
Literal(bitmask, IntegerType)
Literal.create(bitmask, IntegerType)
})

result += GroupExpression(substitution)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
val stringNaN = Literal("NaN", StringType)
val stringNaN = Literal.create("NaN", StringType)

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
private var count: Long = _
private val sum = MutableLiteral(zero.eval(null), calcType)

private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType))
private def addFunction(value: Any) = Add(sum,
Cast(Literal.create(value, expr.dataType), calcType))

override def eval(input: Row): Any = {
if (count == 0L) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ object Literal {
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}

def create(v: Any, dataType: DataType): Literal = Literal(v, dataType)
}

/**
Expand All @@ -62,7 +64,10 @@ object IntegerLiteral {
}
}

case class Literal(value: Any, dataType: DataType) extends LeafExpression {
/**
* In order to do type checking, use Literal.create() instead of constructor
*/
case class Literal protected (value: Any, dataType: DataType) extends LeafExpression {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment explaining why this is protected?


override def foldable: Boolean = true
override def nullable: Boolean = value == null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
Expand All @@ -235,36 +235,36 @@ object NullPropagation extends Rule[LogicalPlan] {
case _ => true
}
if (newChildren.length == 0) {
Literal(null, e.dataType)
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
newChildren(0)
} else {
Coalesce(newChildren)
}

case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType)
case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType)
case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType)
case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)

// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}
case e: BinaryComparison => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}
case e: StringComparison => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}
}
Expand All @@ -284,13 +284,13 @@ object ConstantFolding extends Rule[LogicalPlan] {
case l: Literal => l

// Fold expressions that are foldable.
case e if e.foldable => Literal(e.eval(null), e.dataType)
case e if e.foldable => Literal.create(e.eval(null), e.dataType)

// Fold "literal in (item1, item2, ..., literal, ...)" into true directly.
case In(Literal(v, _), list) if list.exists {
case Literal(candidate, _) if candidate == v => true
case _ => false
} => Literal(true, BooleanType)
} => Literal.create(true, BooleanType)
}
}
}
Expand Down Expand Up @@ -647,7 +647,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {

case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
Cast(
Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)),
Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ class HiveTypeCoercionSuite extends PlanTest {
ruleTest(
Coalesce(Literal(1.0)
:: Literal(1)
:: Literal(1.0, FloatType)
:: Literal.create(1.0, FloatType)
:: Nil),
Coalesce(Cast(Literal(1.0), DoubleType)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal(1.0, FloatType), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
ruleTest(
Coalesce(Literal(1L)
Expand Down
Loading