Skip to content

Commit d306e60

Browse files
committed
Merge branch 'master' into issues/SPARK-2287
Conflicts: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
2 parents 7de5706 + e4899a2 commit d306e60

File tree

2 files changed

+145
-39
lines changed

2 files changed

+145
-39
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,58 +30,59 @@ import org.apache.spark.sql.catalyst.types._
3030
object ScalaReflection {
3131
import scala.reflect.runtime.universe._
3232

33+
case class Schema(dataType: DataType, nullable: Boolean)
34+
3335
/** Returns a Sequence of attributes for the given case class type. */
3436
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
35-
case s: StructType =>
36-
s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
37+
case Schema(s: StructType, _) =>
38+
s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
3739
}
3840

39-
/** Returns a catalyst DataType for the given Scala Type using reflection. */
40-
def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T])
41+
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
42+
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
4143

42-
/** Returns a catalyst DataType for the given Scala Type using reflection. */
43-
def schemaFor(tpe: `Type`): DataType = tpe match {
44+
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
45+
def schemaFor(tpe: `Type`): Schema = tpe match {
4446
case t if t <:< typeOf[Option[_]] =>
4547
val TypeRef(_, _, Seq(optType)) = t
46-
schemaFor(optType)
48+
Schema(schemaFor(optType).dataType, nullable = true)
4749
case t if t <:< typeOf[Product] =>
4850
val formalTypeArgs = t.typeSymbol.asClass.typeParams
4951
val TypeRef(_, _, actualTypeArgs) = t
5052
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
51-
StructType(
52-
params.head.map(p =>
53-
StructField(
54-
p.name.toString,
55-
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)),
56-
nullable = true)))
53+
Schema(StructType(
54+
params.head.map { p =>
55+
val Schema(dataType, nullable) =
56+
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
57+
StructField(p.name.toString, dataType, nullable)
58+
}), nullable = true)
5759
// Need to decide if we actually need a special type here.
58-
case t if t <:< typeOf[Array[Byte]] => BinaryType
60+
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
5961
case t if t <:< typeOf[Array[_]] =>
6062
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
6163
case t if t <:< typeOf[Seq[_]] =>
6264
val TypeRef(_, _, Seq(elementType)) = t
63-
ArrayType(schemaFor(elementType))
65+
Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
6466
case t if t <:< typeOf[Map[_,_]] =>
6567
val TypeRef(_, _, Seq(keyType, valueType)) = t
66-
MapType(schemaFor(keyType), schemaFor(valueType))
67-
case t if t <:< typeOf[String] => StringType
68-
case t if t <:< typeOf[Timestamp] => TimestampType
69-
case t if t <:< typeOf[BigDecimal] => DecimalType
70-
case t if t <:< typeOf[java.lang.Integer] => IntegerType
71-
case t if t <:< typeOf[java.lang.Long] => LongType
72-
case t if t <:< typeOf[java.lang.Double] => DoubleType
73-
case t if t <:< typeOf[java.lang.Float] => FloatType
74-
case t if t <:< typeOf[java.lang.Short] => ShortType
75-
case t if t <:< typeOf[java.lang.Byte] => ByteType
76-
case t if t <:< typeOf[java.lang.Boolean] => BooleanType
77-
// TODO: The following datatypes could be marked as non-nullable.
78-
case t if t <:< definitions.IntTpe => IntegerType
79-
case t if t <:< definitions.LongTpe => LongType
80-
case t if t <:< definitions.DoubleTpe => DoubleType
81-
case t if t <:< definitions.FloatTpe => FloatType
82-
case t if t <:< definitions.ShortTpe => ShortType
83-
case t if t <:< definitions.ByteTpe => ByteType
84-
case t if t <:< definitions.BooleanTpe => BooleanType
68+
Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
69+
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
70+
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
71+
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
72+
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
73+
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
74+
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
75+
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
76+
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
77+
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
78+
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
79+
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
80+
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
81+
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
82+
case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
83+
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
84+
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
85+
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
8586
}
8687

8788
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 110 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,128 @@ import org.scalatest.FunSuite
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.types._
2626

27+
case class PrimitiveData(
28+
intField: Int,
29+
longField: Long,
30+
doubleField: Double,
31+
floatField: Float,
32+
shortField: Short,
33+
byteField: Byte,
34+
booleanField: Boolean)
35+
36+
case class NullableData(
37+
intField: java.lang.Integer,
38+
longField: java.lang.Long,
39+
doubleField: java.lang.Double,
40+
floatField: java.lang.Float,
41+
shortField: java.lang.Short,
42+
byteField: java.lang.Byte,
43+
booleanField: java.lang.Boolean,
44+
stringField: String,
45+
decimalField: BigDecimal,
46+
timestampField: Timestamp,
47+
binaryField: Array[Byte])
48+
49+
case class OptionalData(
50+
intField: Option[Int],
51+
longField: Option[Long],
52+
doubleField: Option[Double],
53+
floatField: Option[Float],
54+
shortField: Option[Short],
55+
byteField: Option[Byte],
56+
booleanField: Option[Boolean])
57+
58+
case class ComplexData(
59+
arrayField: Seq[Int],
60+
mapField: Map[Int, String],
61+
structField: PrimitiveData)
62+
2763
case class GenericData[A](
2864
genericField: A)
2965

3066
class ScalaReflectionSuite extends FunSuite {
67+
import ScalaReflection._
68+
69+
test("primitive data") {
70+
val schema = schemaFor[PrimitiveData]
71+
assert(schema === Schema(
72+
StructType(Seq(
73+
StructField("intField", IntegerType, nullable = false),
74+
StructField("longField", LongType, nullable = false),
75+
StructField("doubleField", DoubleType, nullable = false),
76+
StructField("floatField", FloatType, nullable = false),
77+
StructField("shortField", ShortType, nullable = false),
78+
StructField("byteField", ByteType, nullable = false),
79+
StructField("booleanField", BooleanType, nullable = false))),
80+
nullable = true))
81+
}
82+
83+
test("nullable data") {
84+
val schema = schemaFor[NullableData]
85+
assert(schema === Schema(
86+
StructType(Seq(
87+
StructField("intField", IntegerType, nullable = true),
88+
StructField("longField", LongType, nullable = true),
89+
StructField("doubleField", DoubleType, nullable = true),
90+
StructField("floatField", FloatType, nullable = true),
91+
StructField("shortField", ShortType, nullable = true),
92+
StructField("byteField", ByteType, nullable = true),
93+
StructField("booleanField", BooleanType, nullable = true),
94+
StructField("stringField", StringType, nullable = true),
95+
StructField("decimalField", DecimalType, nullable = true),
96+
StructField("timestampField", TimestampType, nullable = true),
97+
StructField("binaryField", BinaryType, nullable = true))),
98+
nullable = true))
99+
}
100+
101+
test("optinal data") {
102+
val schema = schemaFor[OptionalData]
103+
assert(schema === Schema(
104+
StructType(Seq(
105+
StructField("intField", IntegerType, nullable = true),
106+
StructField("longField", LongType, nullable = true),
107+
StructField("doubleField", DoubleType, nullable = true),
108+
StructField("floatField", FloatType, nullable = true),
109+
StructField("shortField", ShortType, nullable = true),
110+
StructField("byteField", ByteType, nullable = true),
111+
StructField("booleanField", BooleanType, nullable = true))),
112+
nullable = true))
113+
}
114+
115+
test("complex data") {
116+
val schema = schemaFor[ComplexData]
117+
assert(schema === Schema(
118+
StructType(Seq(
119+
StructField("arrayField", ArrayType(IntegerType), nullable = true),
120+
StructField("mapField", MapType(IntegerType, StringType), nullable = true),
121+
StructField(
122+
"structField",
123+
StructType(Seq(
124+
StructField("intField", IntegerType, nullable = false),
125+
StructField("longField", LongType, nullable = false),
126+
StructField("doubleField", DoubleType, nullable = false),
127+
StructField("floatField", FloatType, nullable = false),
128+
StructField("shortField", ShortType, nullable = false),
129+
StructField("byteField", ByteType, nullable = false),
130+
StructField("booleanField", BooleanType, nullable = false))),
131+
nullable = true))),
132+
nullable = true))
133+
}
31134

32135
test("generic data") {
33136
val schema = ScalaReflection.schemaFor[GenericData[Int]]
34-
assert(schema ===
137+
assert(schema === Schema(
35138
StructType(Seq(
36-
StructField("genericField", IntegerType, nullable = true))))
139+
StructField("genericField", IntegerType, nullable = false))),
140+
nullable = true))
37141
}
38142

39143
test("tuple data") {
40144
val schema = ScalaReflection.schemaFor[(Int, String)]
41-
assert(schema ===
145+
assert(schema === Schema(
42146
StructType(Seq(
43-
StructField("_1", IntegerType, nullable = true),
44-
StructField("_2", StringType, nullable = true))))
147+
StructField("_1", IntegerType, nullable = false),
148+
StructField("_2", StringType, nullable = true))),
149+
nullable = true))
45150
}
46151
}

0 commit comments

Comments
 (0)