@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
@@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper {
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
// SPARK-18058: we shall not care about the nullability of columns
-val widerType = TypeCoercion.findWiderTypeForTwo(
-dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis)
-if (widerType.isEmpty) {
+if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
failAnalysis(
s"""
|${operator.nodeName} can only be performed on tables with the compatible
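The CheckAnalysis hunk above drops the explicit case-sensitivity argument: findWiderTypeForTwo now picks the setting up indirectly, through DataType.sameType reading SQLConf.get. A rough spark-shell sketch of the behavior this is meant to support (the session, DataFrames, and the expectation about case-insensitive struct matching are illustrative assumptions, not part of this diff):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
// Case-insensitive analysis is the default; set it explicitly for clarity.
spark.conf.set("spark.sql.caseSensitive", "false")

val df1 = spark.sql("SELECT named_struct('a', 1) AS s")   // s: struct<a:int>
val df2 = spark.sql("SELECT named_struct('A', 2) AS s")   // s: struct<A:int>

// With caseSensitive=false the set-operation check should treat the two struct types as
// compatible; with caseSensitive=true the same union is expected to fail analysis with
// "... can only be performed on tables with the compatible column types".
df1.union(df2).show()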
@@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
// For each column, traverse all the values and find a common data type and nullability.
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
val inputTypes = column.map(_.dataType)
-val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion(
-inputTypes, conf.caseSensitiveAnalysis)
-val tpe = wideType.getOrElse {
+val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
table.failAnalysis(s"incompatible types found in column $name for inline table")
}
StructField(name, tpe, nullable = column.exists(_.nullable))
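ResolveInlineTables uses the no-string-promotion variant of the widening helper when typing a VALUES list. Continuing the spark-shell sketch above (the queries and the exact result type are illustrative, not taken from this PR):

// INT and a DECIMAL literal widen to a common decimal type, so this inline table resolves.
spark.sql("SELECT * FROM VALUES (1), (2.5) AS t(col)").printSchema()

// String promotion is deliberately excluded here, so mixing INT and STRING in one column
// is expected to fail analysis with
// "incompatible types found in column col for inline table".
spark.sql("SELECT * FROM VALUES (1), ('a') AS t(col)")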

Large diffs are not rendered by default.

@@ -27,7 +27,7 @@ import scala.util.matching.Regex

import org.apache.hadoop.fs.Path

-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
@@ -107,13 +107,7 @@ object SQLConf {
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that does not involve SparkSession) in serial order.
*/
-def get: SQLConf = {
-if (Utils.isTesting && TaskContext.get != null) {
-// we're accessing it during task execution, fail.
-throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
-}
-confGetter.get()()
-}
+def get: SQLConf = confGetter.get()()

val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
@@ -1280,6 +1274,12 @@
class SQLConf extends Serializable with Logging {
import SQLConf._

+if (Utils.isTesting && SparkEnv.get != null) {
+// assert that we're only accessing it on the driver.
+assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
+"SQLConf should only be created and accessed on the driver.")
+}

/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
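The per-call guard that used to live in SQLConf.get (throwing when invoked from a running task) is replaced by a test-only assertion in the SQLConf constructor, keyed on SparkEnv.get.executorId. A simplified sketch of the getter indirection that get now returns directly (a paraphrase of the pattern with a made-up object name, not Spark's exact code):

import java.util.concurrent.atomic.AtomicReference
import org.apache.spark.sql.internal.SQLConf

object ConfAccess {
  // Fallback used when no active SparkSession has registered its own supplier.
  private val fallbackConf = new SQLConf
  private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf)

  // An active session can swap in a getter that returns its session-specific conf.
  def setConfGetter(getter: () => SQLConf): Unit = confGetter.set(getter)

  // Same shape as SQLConf.get after this change: no per-call driver check; the
  // driver-only assertion now runs once, when a SQLConf instance is constructed,
  // and only when Utils.isTesting is true.
  def get: SQLConf = confGetter.get()()
}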
@@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType {
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def sameType(other: DataType): Boolean =
-DataType.equalsIgnoreNullability(this, other)
+if (SQLConf.get.caseSensitiveAnalysis) {
+DataType.equalsIgnoreNullability(this, other)
+} else {
+DataType.equalsIgnoreCaseAndNullability(this, other)
+}

/**
* Returns the same data type but set all nullability fields are true
@@ -214,7 +218,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
-private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
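The net effect of the sameType change is that struct field-name case is ignored whenever spark.sql.caseSensitive is false, without callers having to thread the flag through. A rough illustration of the intended semantics (sameType is private[spark], so assume this snippet lives somewhere under the org.apache.spark package):

import org.apache.spark.sql.types._

val left  = new StructType().add("col", IntegerType, nullable = true)
val right = new StructType().add("COL", IntegerType, nullable = false)

// Nullability is always ignored. Under case-insensitive analysis (the default) the
// differing field-name case is ignored too, so this is expected to be true; with
// spark.sql.caseSensitive=true it falls back to the name-sensitive comparison.
val compatible = left.sameType(right)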
@@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest {
}

private def checkWidenType(
-widenFunc: (DataType, DataType, Boolean) => Option[DataType],
+widenFunc: (DataType, DataType) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
-var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
+var found = widenFunc(t1, t2)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
if (isSymmetric) {
-found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
+found = widenFunc(t2, t1)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}
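With the Boolean parameter gone, the widening helpers can be passed to checkWidenType directly as two-argument functions, for example (hypothetical calls, written in the style the suite uses elsewhere):

// findWiderTypeForTwo performs string promotion, findTightestCommonType does not,
// so the two are expected to disagree on (StringType, IntegerType).
checkWidenType(TypeCoercion.findWiderTypeForTwo, StringType, IntegerType, Some(StringType))
checkWidenType(TypeCoercion.findTightestCommonType, StringType, IntegerType, None)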
@@ -524,29 +524,29 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._

-ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))

-ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}

test("cast NullType for binary operators") {
import TypeCoercionSuite._

-ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))

-ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}

test("coalesce casts") {
-val rule = TypeCoercion.FunctionArgumentConversion(conf)
+val rule = TypeCoercion.FunctionArgumentConversion

val intLit = Literal(1)
val longLit = Literal.create(1L)
@@ -606,7 +606,7 @@
}

test("CreateArray casts") {
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -616,7 +616,7 @@
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))

-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal("a")
@@ -626,15 +626,15 @@
:: Cast(Literal("a"), StringType)
:: Nil))

-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal(1)
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
:: Literal(1).cast(DecimalType(13, 3))
:: Nil))

-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
@@ -647,7 +647,7 @@

test("CreateMap casts") {
// type coercion for map keys
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -658,7 +658,7 @@
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -670,7 +670,7 @@
:: Literal("b")
:: Nil))
// type coercion for map values
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2)
@@ -681,7 +681,7 @@
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
@@ -693,7 +693,7 @@
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
// type coercion for both map keys and values
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2.0)
@@ -708,7 +708,7 @@

test("greatest/least cast") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -717,7 +717,7 @@
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -726,7 +726,7 @@
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -735,7 +735,7 @@
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
:: Literal(1).cast(DoubleType)
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(null, DecimalType(15, 0))
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -744,7 +744,7 @@
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
:: Literal(1).cast(DecimalType(20, 5))
:: Nil))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(2L, LongType)
:: Literal(1)
:: Literal.create(null, DecimalType(10, 5))
@@ -757,25 +757,25 @@
}

test("nanvl casts") {
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
-ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
}

test("type coercion for If") {
-val rule = TypeCoercion.IfCoercion(conf)
+val rule = TypeCoercion.IfCoercion
val intLit = Literal(1)
val doubleLit = Literal(1.0)
val trueLit = Literal.create(true, BooleanType)
@@ -823,20 +823,20 @@
}

test("type coercion for CaseKeyWhen") {
-ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
-ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ruleTest(TypeCoercion.CaseWhenCoercion,
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
-ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
-ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
@@ -1085,7 +1085,7 @@
private val timeZoneResolver = ResolveTimeZone(new SQLConf)

private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
-timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
+timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
}

test("WidenSetOperationTypes for except and intersect") {
@@ -1256,7 +1256,7 @@

test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
-val rules = Seq(FunctionArgumentConversion(conf), Division)
+val rules = Seq(FunctionArgumentConversion, Division)
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1275,7 +1275,7 @@
}

test("SPARK-17117 null type coercion in divide") {
-val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
+val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
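The remaining test updates are mechanical: rules that used to be constructed with the conf are now either parameter-free objects (FunctionArgumentConversion, IfCoercion, CaseWhenCoercion, WidenSetOperationTypes) or plain classes (new ImplicitTypeCasts(conf)). A small sketch of applying one such rule outside the test harness (catalyst is an internal API, so treat this as illustrative only):

import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}

// A tiny resolved plan whose Coalesce mixes INT and LONG children.
val expr = Coalesce(Seq(Literal(1), Literal(1L)))
val plan = Project(Seq(Alias(expr, "c")()), LocalRelation())

// FunctionArgumentConversion no longer needs the conf to be passed in: anything
// conf-dependent is read from SQLConf.get inside the rule. The INT child is expected
// to come back wrapped in a cast to LongType.
val coerced = TypeCoercion.FunctionArgumentConversion(plan)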