@@ -41,6 +41,21 @@ import org.apache.spark.util.Utils
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))

private val nonDecimalNameToType = {
  (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all)
    .map(t => t.typeName -> t).toMap
}

Contributor Author: this is just copied from down below (originally in PrimitiveType).

/** Given the string representation of a type, return its DataType */
private def nameToType(name: String): DataType = {
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
name match {
case "decimal" => DecimalType.Unlimited
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
case other => nonDecimalNameToType(other)
}
}
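For reference, a sketch of what this parser accepts (illustrative calls; the method is private, so these lines are for exposition only):

  nameToType("decimal")         // DecimalType.Unlimited
  nameToType("decimal(10, 2)")  // DecimalType(10, 2); the regex tolerates whitespace
  nameToType("string")          // StringType, via the nonDecimalNameToType lookup
  nameToType("no-such-type")    // throws NoSuchElementException from the Map lookup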

private object JSortedObject {
def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
case JObject(seq) => Some(seq.toList.sortBy(_._1))
@@ -51,7 +66,7 @@ object DataType
// NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
private def parseDataType(json: JValue): DataType = json match {
case JString(name) =>
PrimitiveType.nameToType(name)
nameToType(name)

case JSortedObject(
("containsNull", JBool(n)),
@@ -190,13 +205,11 @@ object DataType {
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
equalsIgnoreNullability(leftValueType, rightValueType)
case (StructType(leftFields), StructType(rightFields)) =>
leftFields.size == rightFields.size &&
leftFields.zip(rightFields)
.forall{
case (left, right) =>
left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
}
case (left, right) => left == right
leftFields.length == rightFields.length &&
  leftFields.zip(rightFields).forall { case (l, r) =>
    l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
  }
case (l, r) => l == r
}
}

Contributor Author: left / right were shadowing outer-scope variables, so I renamed them l, r.
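For context, a minimal sketch of the shadowing hazard the rename sidesteps (simplified signature, assumed for illustration):

  def equalsIgnoreNullability(left: DataType, right: DataType): Boolean =
    (left, right) match {
      // Naming the patterns `left`/`right` here would create new bindings that
      // shadow the parameters above; `l`/`r` keep the outer names unambiguous.
      case (l, r) => l == r
    }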

@@ -225,12 +238,11 @@ object DataType {
equalsIgnoreCompatibleNullability(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.size == toFields.size &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
fromField.name == toField.name &&
(toField.nullable || !fromField.nullable) &&
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (fromField, toField) =>
fromField.name == toField.name &&
(toField.nullable || !fromField.nullable) &&
equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
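The (toField.nullable || !fromField.nullable) check encodes "the destination must be at least as permissive as the source". A sketch of the intended semantics (hypothetical values; the method is private[sql], so this is illustration only):

  val from = StructType(Seq(StructField("a", IntegerType, nullable = false)))
  val to   = StructType(Seq(StructField("a", IntegerType, nullable = true)))
  equalsIgnoreCompatibleNullability(from, to)  // true: non-nullable data fits a nullable field
  equalsIgnoreCompatibleNullability(to, from)  // false: a nullable source may emit nulls the target forbids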
@@ -256,8 +268,6 @@ abstract class DataType {
/** The default size of a value of this data type. */
def defaultSize: Int

def isPrimitive: Boolean = false

def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase

private[sql] def jsonValue: JValue = typeName
@@ -307,26 +317,6 @@ protected[sql] object NativeType {
}


protected[sql] trait PrimitiveType extends DataType {
override def isPrimitive: Boolean = true
}


protected[sql] object PrimitiveType {
private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap

/** Given the string representation of a type, return its DataType */
private[sql] def nameToType(name: String): DataType = {
val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
name match {
case "decimal" => DecimalType.Unlimited
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
case other => nonDecimalNameToType(other)
}
}
}

protected[sql] abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
@@ -346,7 +336,7 @@ protected[sql] abstract class NativeType {
* @group dataType
*/
@DeveloperApi
class StringType private() extends NativeType with PrimitiveType {
class StringType private() extends NativeType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
@@ -373,7 +363,7 @@ case object StringType extends StringType
* @group dataType
*/
@DeveloperApi
class BinaryType private() extends NativeType with PrimitiveType {
class BinaryType private() extends NativeType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
@@ -407,7 +397,7 @@ case object BinaryType extends BinaryType
* @group dataType
*/
@DeveloperApi
class BooleanType private() extends NativeType with PrimitiveType {
class BooleanType private() extends NativeType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
@@ -492,7 +482,7 @@ case object DateType extends DateType
*
* @group dataType
*/
abstract class NumericType extends NativeType with PrimitiveType {
abstract class NumericType extends NativeType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
// type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
@@ -146,7 +146,8 @@ private[sql] object CatalystConverter {
}
}
// All other primitive types use the default converter
case ctype: PrimitiveType => { // note: need the type tag here!
case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => {
// note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
}
case _ => throw new RuntimeException(
@@ -324,9 +325,9 @@ private[parquet] class CatalystGroupConverter(

override def start(): Unit = {
current = ArrayBuffer.fill(size)(null)
converters.foreach {
converter => if (!converter.isPrimitive) {
converter.asInstanceOf[CatalystConverter].clearBuffer
converters.foreach { converter =>
if (!converter.isPrimitive) {
converter.asInstanceOf[CatalystConverter].clearBuffer()
}
}
}
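The added parentheses follow the Scala convention that side-effecting methods are declared and called with (); a minimal sketch of the distinction (hypothetical class):

  class Buffered {
    private val buf = scala.collection.mutable.ArrayBuffer[Int]()
    def clearBuffer(): Unit = buf.clear()  // mutates state, so keep the parens
    def size: Int = buf.size               // pure accessor, so no parens
  }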
@@ -612,7 +613,7 @@ private[parquet] class CatalystArrayConverter(

override def start(): Unit = {
if (!converter.isPrimitive) {
converter.asInstanceOf[CatalystConverter].clearBuffer
converter.asInstanceOf[CatalystConverter].clearBuffer()
}
}

@@ -268,7 +268,7 @@ private[sql] case class InsertIntoParquetTable(
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)

val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
log.debug("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
@@ -48,8 +48,10 @@ private[parquet] case class ParquetTypeInfo(
length: Option[Int] = None)

private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean =
classOf[PrimitiveType] isAssignableFrom ctype.getClass
def isPrimitiveType(ctype: DataType): Boolean = ctype match {
case _: NumericType | BooleanType | StringType | BinaryType => true
case _: DataType => false
}

Contributor: We need to put all types other than MapType, ArrayType, and StructType here; Parquet uses this function to create record converters.

Contributor: My bad. CatalystGroupConverter actually can handle all types properly. I thought it only handled complex types at first.
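For reference, a sketch of the inversion the first comment proposes, treating everything except the complex container types as primitive (hypothetical; per the follow-up, the posted version was fine as-is):

  def isPrimitiveType(ctype: DataType): Boolean = ctype match {
    case _: MapType | _: ArrayType | _: StructType => false
    case _ => true
  }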

def toPrimitiveDataType(
parquetType: ParquetPrimitiveType,
@@ -634,12 +634,13 @@ private[sql] case class ParquetRelation2(
// before calling execute().

val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) {
log.debug("Initializing MutableRowWriteSupport")
classOf[MutableRowWriteSupport]
} else {
classOf[RowWriteSupport]
}
val writeSupport =
if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
log.debug("Initializing MutableRowWriteSupport")
classOf[MutableRowWriteSupport]
} else {
classOf[RowWriteSupport]
}

ParquetOutputFormat.setWriteSupportClass(job, writeSupport)

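Both write-support call sites hinge on this predicate: only a schema whose column types are all primitive can take the MutableRowWriteSupport fast path. A quick illustration (assumed example schemas):

  Seq(IntegerType, StringType).forall(ParquetTypesConverter.isPrimitiveType)             // true  -> MutableRowWriteSupport
  Seq(IntegerType, ArrayType(IntegerType)).forall(ParquetTypesConverter.isPrimitiveType) // false -> RowWriteSupport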