Skip to content

Commit c10b577

Browse files
committed
optimize lit
1 parent 77b164a commit c10b577

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,17 @@ object functions {
113113
* @group normal_funcs
114114
* @since 1.3.0
115115
*/
116-
def lit(literal: Any): Column = typedLit(literal)
116+
def lit(literal: Any): Column = literal match {
117+
case c: Column => c
118+
case s: Symbol => new ColumnName(s.name)
119+
case _ =>
120+
// This is different from `typedLit`. `typedLit` calls `Literal.create` to use
121+
// `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method,
122+
// `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can
123+
// just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is
124+
// significantly better when there are many threads calling `lit` concurrently.
125+
Column(Literal(literal))
126+
}
117127

118128
/**
119129
* Creates a [[Column]] of literal value.
@@ -134,6 +144,9 @@ object functions {
134144
* The difference between this function and [[lit]] is that this function
135145
* can handle parameterized scala types e.g.: List, Seq and Map.
136146
*
147+
* Note: `typedLit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized
148+
* scala types are not used.
149+
*
137150
* @group normal_funcs
138151
* @since 3.2.0
139152
*/

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
932932
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
933933
}
934934

935+
test("SPARK-37646: lit") {
936+
assert(lit($"foo") == $"foo")
937+
assert(lit('foo) == $"foo")
938+
assert(lit(1) == Column(Literal(1)))
939+
assert(lit(null) == Column(Literal(null, NullType)))
940+
}
941+
935942
test("typedLit") {
943+
assert(typedLit($"foo") == $"foo")
944+
assert(typedLit('foo) == $"foo")
945+
assert(typedLit(1) == Column(Literal(1)))
946+
assert(typedLit[String](null) == Column(Literal(null, StringType)))
947+
936948
val df = Seq(Tuple1(0)).toDF("a")
937949
// Only check the types `lit` cannot handle
938950
checkAnswer(

0 commit comments

Comments
 (0)