Skip to content

Commit e4e8bb5

Browse files
mihailom-dbcloud-fan
authored andcommitted
[SPARK-47972][SQL] Restrict CAST expression for collations
### What changes were proposed in this pull request? Block of syntax CAST(value AS STRING COLLATE collation_name). ### Why are the changes needed? Current state of code allows for calls like CAST(1 AS STRING COLLATE UNICODE). We want to restrict CAST expression to only be able to cast to default collation string, and to only allow COLLATE expression to produce explicitly collated strings. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Test in CollationSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46474 from mihailom-db/SPARK-47972. Authored-by: Mihailo Milosevic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent f9542d0 commit e4e8bb5

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,6 @@ object CollationTypeCasts extends TypeCoercionRule {
132132
def getOutputCollation(expr: Seq[Expression]): StringType = {
133133
val explicitTypes = expr.filter {
134134
case _: Collate => true
135-
case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
136-
cast.dataType.isInstanceOf[StringType]
137135
case _ => false
138136
}
139137
.map(_.dataType.asInstanceOf[StringType].collationId)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser
2020
import java.util.Locale
2121
import java.util.concurrent.TimeUnit
2222

23+
import scala.collection.immutable.Seq
2324
import scala.collection.mutable.{ArrayBuffer, Set}
2425
import scala.jdk.CollectionConverters._
2526
import scala.util.{Left, Right}
@@ -2265,6 +2266,20 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
22652266
*/
22662267
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
22672268
val rawDataType = typedVisit[DataType](ctx.dataType())
2269+
ctx.dataType() match {
2270+
case context: PrimitiveDataTypeContext =>
2271+
val typeCtx = context.`type`()
2272+
if (typeCtx.start.getType == STRING) {
2273+
typeCtx.children.asScala.toSeq match {
2274+
case Seq(_, cctx: CollateClauseContext) =>
2275+
throw QueryParsingErrors.dataTypeUnsupportedError(
2276+
rawDataType.typeName,
2277+
ctx.dataType().asInstanceOf[PrimitiveDataTypeContext])
2278+
case _ =>
2279+
}
2280+
}
2281+
case _ =>
2282+
}
22682283
val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
22692284
ctx.name.getType match {
22702285
case SqlBaseParser.CAST =>
@@ -2284,6 +2299,20 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
22842299
*/
22852300
override def visitCastByColon(ctx: CastByColonContext): Expression = withOrigin(ctx) {
22862301
val rawDataType = typedVisit[DataType](ctx.dataType())
2302+
ctx.dataType() match {
2303+
case context: PrimitiveDataTypeContext =>
2304+
val typeCtx = context.`type`()
2305+
if (typeCtx.start.getType == STRING) {
2306+
typeCtx.children.asScala.toSeq match {
2307+
case Seq(_, cctx: CollateClauseContext) =>
2308+
throw QueryParsingErrors.dataTypeUnsupportedError(
2309+
rawDataType.typeName,
2310+
ctx.dataType().asInstanceOf[PrimitiveDataTypeContext])
2311+
case _ =>
2312+
}
2313+
}
2314+
case _ =>
2315+
}
22872316
val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
22882317
val cast = Cast(expression(ctx.primaryExpression), dataType)
22892318
cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())

sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava
2222
import org.apache.spark.SparkException
2323
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
2424
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.catalyst.parser.ParseException
2526
import org.apache.spark.sql.catalyst.util.CollationFactory
2627
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
2728
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
@@ -830,6 +831,45 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
830831
}
831832
}
832833

834+
test("SPARK-47972: Cast expression limitation for collations") {
835+
checkError(
836+
exception = intercept[ParseException]
837+
(sql("SELECT cast(1 as string collate unicode)")),
838+
errorClass = "UNSUPPORTED_DATATYPE",
839+
parameters = Map(
840+
"typeName" -> toSQLType(StringType("UNICODE"))),
841+
context =
842+
ExpectedContext(fragment = s"cast(1 as string collate unicode)", start = 7, stop = 39)
843+
)
844+
845+
checkError(
846+
exception = intercept[ParseException]
847+
(sql("SELECT 'A' :: string collate unicode")),
848+
errorClass = "UNSUPPORTED_DATATYPE",
849+
parameters = Map(
850+
"typeName" -> toSQLType(StringType("UNICODE"))),
851+
context = ExpectedContext(fragment = s"'A' :: string collate unicode", start = 7, stop = 35)
852+
)
853+
854+
checkAnswer(sql(s"SELECT cast(1 as string)"), Seq(Row("1")))
855+
checkAnswer(sql(s"SELECT cast('A' as string)"), Seq(Row("A")))
856+
857+
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
858+
checkError(
859+
exception = intercept[ParseException]
860+
(sql("SELECT cast(1 as string collate unicode)")),
861+
errorClass = "UNSUPPORTED_DATATYPE",
862+
parameters = Map(
863+
"typeName" -> toSQLType(StringType("UNICODE"))),
864+
context =
865+
ExpectedContext(fragment = s"cast(1 as string collate unicode)", start = 7, stop = 39)
866+
)
867+
868+
checkAnswer(sql(s"SELECT cast(1 as string)"), Seq(Row("1")))
869+
checkAnswer(sql(s"SELECT collation(cast(1 as string))"), Seq(Row("UNICODE")))
870+
}
871+
}
872+
833873
test("SPARK-47431: Default collation set to UNICODE, column type test") {
834874
withTable("t") {
835875
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {

0 commit comments

Comments
 (0)