diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8192350baa06..ec28d8dde38e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -113,7 +113,17 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = typedLit(literal) + def lit(literal: Any): Column = literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, + // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can + // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is + // significantly better when there are many threads calling `lit` concurrently. + Column(Literal(literal)) + } /** * Creates a [[Column]] of literal value. @@ -134,6 +144,9 @@ object functions { * The difference between this function and [[lit]] is that this function * can handle parameterized scala types e.g.: List, Seq and Map. * + * @note `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized + * Scala types are not used. + * * @group normal_funcs * @since 3.2.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 2acc4ff68796..fe56bcb99117 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -932,7 +932,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) } + test("SPARK-37646: lit") { + assert(lit($"foo") == $"foo") + assert(lit('foo) == $"foo") + assert(lit(1) == Column(Literal(1))) + assert(lit(null) == Column(Literal(null, NullType))) + } + test("typedLit") { + assert(typedLit($"foo") == $"foo") + assert(typedLit('foo) == $"foo") + assert(typedLit(1) == Column(Literal(1))) + assert(typedLit[String](null) == Column(Literal(null, StringType))) + val df = Seq(Tuple1(0)).toDF("a") // Only check the types `lit` cannot handle checkAnswer(