
Commit 29163c5

[SPARK-7068][SQL] Remove PrimitiveType

Author: Reynold Xin <[email protected]>

Closes apache#5646 from rxin/remove-primitive-type and squashes the following commits:

01b673d [Reynold Xin] [SPARK-7068][SQL] Remove PrimitiveType

1 parent 2d33323 · commit 29163c5

File tree

5 files changed (+48 -54 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 30 additions & 40 deletions
@@ -41,6 +41,21 @@ import org.apache.spark.util.Utils
 object DataType {
   def fromJson(json: String): DataType = parseDataType(parse(json))
 
+  private val nonDecimalNameToType = {
+    (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all)
+      .map(t => t.typeName -> t).toMap
+  }
+
+  /** Given the string representation of a type, return its DataType */
+  private def nameToType(name: String): DataType = {
+    val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+    name match {
+      case "decimal" => DecimalType.Unlimited
+      case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+      case other => nonDecimalNameToType(other)
+    }
+  }
+
   private object JSortedObject {
     def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
       case JObject(seq) => Some(seq.toList.sortBy(_._1))
@@ -51,7 +66,7 @@ object DataType {
   // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
   private def parseDataType(json: JValue): DataType = json match {
     case JString(name) =>
-      PrimitiveType.nameToType(name)
+      nameToType(name)
 
     case JSortedObject(
       ("containsNull", JBool(n)),
@@ -190,13 +205,11 @@ object DataType {
       equalsIgnoreNullability(leftKeyType, rightKeyType) &&
       equalsIgnoreNullability(leftValueType, rightValueType)
     case (StructType(leftFields), StructType(rightFields)) =>
-      leftFields.size == rightFields.size &&
-        leftFields.zip(rightFields)
-          .forall{
-            case (left, right) =>
-              left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
-          }
-    case (left, right) => left == right
+      leftFields.length == rightFields.length &&
+        leftFields.zip(rightFields).forall { case (l, r) =>
+          l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
+        }
+    case (l, r) => l == r
   }
 }
 
@@ -225,12 +238,11 @@ object DataType {
       equalsIgnoreCompatibleNullability(fromValue, toValue)
 
     case (StructType(fromFields), StructType(toFields)) =>
-      fromFields.size == toFields.size &&
-        fromFields.zip(toFields).forall {
-          case (fromField, toField) =>
-            fromField.name == toField.name &&
-              (toField.nullable || !fromField.nullable) &&
-              equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall { case (fromField, toField) =>
+          fromField.name == toField.name &&
+            (toField.nullable || !fromField.nullable) &&
+            equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
         }
 
     case (fromDataType, toDataType) => fromDataType == toDataType
@@ -256,8 +268,6 @@ abstract class DataType {
   /** The default size of a value of this data type. */
   def defaultSize: Int
 
-  def isPrimitive: Boolean = false
-
   def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
 
   private[sql] def jsonValue: JValue = typeName
@@ -307,26 +317,6 @@ protected[sql] object NativeType {
 }
 
 
-protected[sql] trait PrimitiveType extends DataType {
-  override def isPrimitive: Boolean = true
-}
-
-
-protected[sql] object PrimitiveType {
-  private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
-  private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
-
-  /** Given the string representation of a type, return its DataType */
-  private[sql] def nameToType(name: String): DataType = {
-    val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
-    name match {
-      case "decimal" => DecimalType.Unlimited
-      case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
-      case other => nonDecimalNameToType(other)
-    }
-  }
-}
-
 protected[sql] abstract class NativeType extends DataType {
   private[sql] type JvmType
   @transient private[sql] val tag: TypeTag[JvmType]
@@ -346,7 +336,7 @@ protected[sql] abstract class NativeType
  * @group dataType
  */
 @DeveloperApi
-class StringType private() extends NativeType with PrimitiveType {
+class StringType private() extends NativeType {
   // The companion object and this class is separated so the companion object also subclasses
   // this type. Otherwise, the companion object would be of type "StringType$" in byte code.
   // Defined with a private constructor so the companion object is the only possible instantiation.
@@ -373,7 +363,7 @@ case object StringType extends StringType
  * @group dataType
 */
 @DeveloperApi
-class BinaryType private() extends NativeType with PrimitiveType {
+class BinaryType private() extends NativeType {
   // The companion object and this class is separated so the companion object also subclasses
   // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
   // Defined with a private constructor so the companion object is the only possible instantiation.
@@ -407,7 +397,7 @@ case object BinaryType extends BinaryType
  * @group dataType
 */
 @DeveloperApi
-class BooleanType private() extends NativeType with PrimitiveType {
+class BooleanType private() extends NativeType {
   // The companion object and this class is separated so the companion object also subclasses
   // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
   // Defined with a private constructor so the companion object is the only possible instantiation.
@@ -492,7 +482,7 @@ case object DateType extends DateType
 *
 * @group dataType
 */
-abstract class NumericType extends NativeType with PrimitiveType {
+abstract class NumericType extends NativeType {
  // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
  // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
  // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
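
For readers following the refactor, the sketch below shows the relocated nameToType logic in isolation. The regex and the match structure are lifted from the diff above; the minimal DataType hierarchy (NameToTypeSketch and its toy members) is hypothetical, standing in for Spark's real one:

object NameToTypeSketch {
  sealed trait DataType { def typeName: String }
  case object StringType extends DataType { val typeName = "string" }
  case object BinaryType extends DataType { val typeName = "binary" }
  case object UnlimitedDecimal extends DataType { val typeName = "decimal" }
  final case class DecimalType(precision: Int, scale: Int) extends DataType {
    val typeName = s"decimal($precision,$scale)"
  }

  // Non-decimal types resolve through a simple name -> type map,
  // like nonDecimalNameToType in the diff.
  private val nonDecimalNameToType: Map[String, DataType] =
    Seq(StringType, BinaryType).map(t => t.typeName -> t).toMap

  // Same regex as in the diff: "decimal(p, s)" with optional whitespace.
  private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r

  def nameToType(name: String): DataType = name match {
    case "decimal" => UnlimitedDecimal
    case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
    case other => nonDecimalNameToType(other)
  }

  def main(args: Array[String]): Unit = {
    println(nameToType("decimal(10, 2)")) // DecimalType(10,2)
    println(nameToType("string"))         // StringType
  }
}

As in the real code, an unknown name falls through to the map lookup and fails with a NoSuchElementException.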

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 6 additions & 5 deletions
@@ -146,7 +146,8 @@ private[sql] object CatalystConverter {
         }
       }
       // All other primitive types use the default converter
-      case ctype: PrimitiveType => { // note: need the type tag here!
+      case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => {
+        // note: need the type tag here!
         new CatalystPrimitiveConverter(parent, fieldIndex)
       }
       case _ => throw new RuntimeException(
@@ -324,9 +325,9 @@ private[parquet] class CatalystGroupConverter(
 
   override def start(): Unit = {
     current = ArrayBuffer.fill(size)(null)
-    converters.foreach {
-      converter => if (!converter.isPrimitive) {
-        converter.asInstanceOf[CatalystConverter].clearBuffer
+    converters.foreach { converter =>
+      if (!converter.isPrimitive) {
+        converter.asInstanceOf[CatalystConverter].clearBuffer()
       }
     }
   }
@@ -612,7 +613,7 @@ private[parquet] class CatalystArrayConverter(
 
   override def start(): Unit = {
     if (!converter.isPrimitive) {
-      converter.asInstanceOf[CatalystConverter].clearBuffer
+      converter.asInstanceOf[CatalystConverter].clearBuffer()
     }
   }
 
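
Besides pulling the lambda parameter onto the brace line, these hunks add empty parentheses to clearBuffer, following the Scala convention that side-effecting calls keep their parens. A small hypothetical illustration (Buffer and its members are made up for this sketch):

object ClearBufferStyle {
  class Buffer {
    private var items = List.empty[Int]
    def add(i: Int): Unit = { items = i :: items }
    def clearBuffer(): Unit = { items = Nil } // declared with (), so called with ()
    def size: Int = items.size                // pure accessor: no parens
  }

  def main(args: Array[String]): Unit = {
    val buffers = Seq(new Buffer, new Buffer)
    buffers.foreach { b =>
      b.add(1)
      b.clearBuffer() // matches the diff's change from `.clearBuffer`
    }
  }
}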

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ private[sql] case class InsertIntoParquetTable(
     val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
 
     val writeSupport =
-      if (child.output.map(_.dataType).forall(_.isPrimitive)) {
+      if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
        log.debug("Initializing MutableRowWriteSupport")
        classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
      } else {
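
Note that ParquetTypesConverter.isPrimitiveType is passed to forall as a method reference; Scala eta-expands it into a DataType => Boolean function, so it slots in where the old `_.isPrimitive` lambda was. A minimal sketch under hypothetical stand-in types:

object EtaExpansionSketch {
  sealed trait DataType
  case object IntLikeType extends DataType
  case object MapLikeType extends DataType

  object TypesConverter {
    def isPrimitiveType(t: DataType): Boolean = t == IntLikeType
  }

  def main(args: Array[String]): Unit = {
    val schema: Seq[DataType] = Seq(IntLikeType, IntLikeType)
    // The method reference eta-expands to DataType => Boolean,
    // replacing the old lambda one-for-one.
    println(schema.forall(TypesConverter.isPrimitiveType))                  // true
    println((schema :+ MapLikeType).forall(TypesConverter.isPrimitiveType)) // false
  }
}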

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 4 additions & 2 deletions
@@ -48,8 +48,10 @@ private[parquet] case class ParquetTypeInfo(
     length: Option[Int] = None)
 
 private[parquet] object ParquetTypesConverter extends Logging {
-  def isPrimitiveType(ctype: DataType): Boolean =
-    classOf[PrimitiveType] isAssignableFrom ctype.getClass
+  def isPrimitiveType(ctype: DataType): Boolean = ctype match {
+    case _: NumericType | BooleanType | StringType | BinaryType => true
+    case _: DataType => false
+  }
 
   def toPrimitiveDataType(
       parquetType: ParquetPrimitiveType,
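
The reflective isAssignableFrom check becomes a plain pattern match. The alternative pattern mixes two kinds of patterns: `_: NumericType` is a typed pattern matching any NumericType subtype, while BooleanType, StringType, and BinaryType are stable identifiers matching those singleton objects. A runnable sketch over a toy hierarchy (all names below are stand-ins, not Spark's classes):

object IsPrimitiveSketch {
  sealed trait DataType
  abstract class NumericType extends DataType
  case object IntegerType extends NumericType
  case object DoubleType extends NumericType
  case object BooleanType extends DataType
  case object StringType extends DataType
  case object BinaryType extends DataType
  case object MapLikeType extends DataType // stands in for a complex type

  def isPrimitiveType(ctype: DataType): Boolean = ctype match {
    // `_: NumericType` matches any subtype; the rest match singletons.
    case _: NumericType | BooleanType | StringType | BinaryType => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isPrimitiveType(DoubleType))  // true, via the typed pattern
    println(isPrimitiveType(MapLikeType)) // false
  }
}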

sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 7 additions & 6 deletions
@@ -634,12 +634,13 @@ private[sql] case class ParquetRelation2(
     // before calling execute().
 
     val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
-    val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) {
-      log.debug("Initializing MutableRowWriteSupport")
-      classOf[MutableRowWriteSupport]
-    } else {
-      classOf[RowWriteSupport]
-    }
+    val writeSupport =
+      if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
+        log.debug("Initializing MutableRowWriteSupport")
+        classOf[MutableRowWriteSupport]
+      } else {
+        classOf[RowWriteSupport]
+      }
 
     ParquetOutputFormat.setWriteSupportClass(job, writeSupport)
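
Apart from swapping in ParquetTypesConverter.isPrimitiveType, this hunk only reflows the expression: the whole if/else evaluates to a Class value, which is then registered with the output format. A hypothetical sketch of that shape (the two write-support classes here are stand-ins):

object WriteSupportSelection {
  class RowWriteSupport
  class MutableRowWriteSupport extends RowWriteSupport

  // The if/else is an expression: both branches yield a Class value,
  // and the result widens to Class[_ <: RowWriteSupport].
  def chooseWriteSupport(allPrimitive: Boolean): Class[_ <: RowWriteSupport] =
    if (allPrimitive) classOf[MutableRowWriteSupport]
    else classOf[RowWriteSupport]

  def main(args: Array[String]): Unit = {
    println(chooseWriteSupport(allPrimitive = true).getSimpleName)  // MutableRowWriteSupport
    println(chooseWriteSupport(allPrimitive = false).getSimpleName) // RowWriteSupport
  }
}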
