Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

Expand Down Expand Up @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, _.numBytes() != 0)
buildCast[UTF8String](_, s => {
if (StringUtils.isTrueString(s)) {
true
} else if (StringUtils.isFalseString(s)) {
false
} else {
null
}
})
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>
Expand Down Expand Up @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)

private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
(c, evPrim, evNull) =>
s"""
if ($stringUtils.isTrueString($c)) {
$evPrim = true;
} else if ($stringUtils.isFalseString($c)) {
$evPrim = false;
} else {
$evNull = true;
}
"""
case TimestampType =>
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
case DateType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util

import java.util.regex.Pattern

import org.apache.spark.unsafe.types.UTF8String

object StringUtils {

// replace the _ with .{1} exactly match 1 time of any character
Expand All @@ -44,4 +46,10 @@ object StringUtils {
v
}
}

// Lower-case spellings treated as boolean true, pre-converted to UTF8String so that
// lookups compare against UTF8String values directly.
private[this] val trueStrings =
  Set("t", "true", "y", "yes", "1").map(str => UTF8String.fromString(str))

// Lower-case spellings treated as boolean false, pre-converted the same way.
private[this] val falseStrings =
  Set("f", "false", "n", "no", "0").map(str => UTF8String.fromString(str))

/**
 * Tells whether `s`, once lower-cased, is one of the recognised "true" spellings.
 * The input is not trimmed, so surrounding whitespace prevents a match.
 */
def isTrueString(s: UTF8String): Boolean = trueStrings(s.toLowerCase)

/**
 * Tells whether `s`, once lower-cased, is one of the recognised "false" spellings.
 * The input is not trimmed, so surrounding whitespace prevents a match.
 */
def isFalseString(s: UTF8String): Boolean = falseStrings(s.toLowerCase)
}
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("cast from array") {
val array = Literal.create(Seq("123", "abc", "", null),
val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
val array_notNull = Literal.create(Seq("123", "abc", ""),
val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))

checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
Expand All @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false, null))
checkEvaluation(ret, Seq(null, true, false, null))
}
{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
Expand All @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false))
checkEvaluation(ret, Seq(null, true, false))
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === true)
checkEvaluation(ret, Seq(true, true, false))
checkEvaluation(ret, Seq(null, true, false))
}

{
Expand All @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {

test("cast from map") {
val map = Literal.create(
Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
Map("a" -> "123", "b" -> "abc", "c" -> ""),
Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))

checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
Expand All @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
}
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
Expand All @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === true)
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
Expand All @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("abc"),
UTF8String.fromString(""),
UTF8String.fromString("true"),
UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
Expand All @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
UTF8String.fromString("abc"),
UTF8String.fromString("")),
UTF8String.fromString("true"),
UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
Expand Down Expand Up @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("c", BooleanType, nullable = true),
StructField("d", BooleanType, nullable = true))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false, null))
checkEvaluation(ret, InternalRow(null, true, false, null))
}
{
val ret = cast(struct, StructType(Seq(
Expand Down Expand Up @@ -704,15 +704,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false))
checkEvaluation(ret, InternalRow(null, true, false))
}
{
val ret = cast(struct_notNull, StructType(Seq(
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
assert(ret.resolved === true)
checkEvaluation(ret, InternalRow(true, true, false))
checkEvaluation(ret, InternalRow(null, true, false))
}

{
Expand All @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
Row(
Seq("123", "abc", ""),
Map("a" ->"123", "b" -> "abc", "c" -> ""),
Seq("123", "true", "f"),
Map("a" ->"123", "b" -> "true", "c" -> "f"),
Row(0)),
StructType(Seq(
StructField("a",
Expand All @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved === true)
checkEvaluation(ret, Row(
Seq(123, null, null),
Map("a" -> true, "b" -> true, "c" -> false),
Map("a" -> null, "b" -> true, "c" -> false),
Row(0L)))
}

test("case between string and interval") {
test("cast between string and interval") {
import org.apache.spark.unsafe.types.CalendarInterval

checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
Expand All @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StringType),
"interval 1 years 3 months -3 days")
}

test("cast string to boolean") {
  // Spellings that must cast to true; "tRUe" exercises case-insensitive matching.
  Seq("t", "true", "tRUe", "y", "yes", "1").foreach { input =>
    checkCast(input, true)
  }

  // Spellings that must cast to false; "FAlsE" exercises case-insensitive matching.
  Seq("f", "false", "FAlsE", "n", "no", "0").foreach { input =>
    checkCast(input, false)
  }

  // Anything else, including the empty string, must evaluate to null.
  Seq("abc", "").foreach { input =>
    checkEvaluation(cast(input, BooleanType), null)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}

// Writes a two-row table partitioned by a boolean column "b" and checks that reading
// the table back yields the original (id, b) pairs.
// NOTE(review): unlike the neighbouring saveAsTable tests, this write does not call
// .format(dataSourceName), so it presumably uses the default data source — confirm
// that this is intentional.
test("saveAsTable()/load() - partitioned table - boolean type") {
  // The write happens inside withTable so that the cleanup of "t" also covers a write
  // that fails part-way, instead of leaving the table behind for later tests.
  withTable("t") {
    sqlContext.range(2)
      .select('id, ('id % 2 === 0).as("b"))
      .write.partitionBy("b").saveAsTable("t")

    checkAnswer(
      sqlContext.table("t").sort('id),
      Row(0, true) :: Row(1, false) :: Nil
    )
  }
}

test("saveAsTable()/load() - partitioned table - Overwrite") {
partitionedTestDF.write
.format(dataSourceName)
Expand Down