@@ -23,6 +23,8 @@ import java.math.MathContext

import scala.util.Random

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

@@ -84,6 +86,7 @@ object RandomDataGenerator {
* random data generator is defined for that data type. The generated values will use an external
* representation of the data type; for example, the random generator for [[DateType]] will return
* instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]].
* For a [[UserDefinedType]] for a class X, the generator returns instances of class X.
*
* @param dataType the type to generate values for
* @param nullable whether null values should be generated
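As a rough usage sketch of the contract described in this doc comment (not part of the patch; it only assumes the forType(dataType, nullable, seed) signature exercised by the new test added further down), a caller receives an optional nullary generator and invokes it repeatedly:

    // Hypothetical caller code, shown only to illustrate the external representations above.
    val maybeGen = RandomDataGenerator.forType(DateType, nullable = false, seed = Some(42L))
    maybeGen.foreach { gen =>
      val value = gen()   // external representation, e.g. a java.sql.Date for DateType
      assert(value.isInstanceOf[java.sql.Date])
    }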
@@ -106,7 +109,22 @@ object RandomDataGenerator {
})
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
case TimestampType =>
val generator =
() => {
var milliseconds = rand.nextLong() % 253402329599999L
// 253402329599999L is the number of milliseconds since January 1, 1970, 00:00:00 GMT
// for "9999-12-31 23:59:59.999999".
// -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
// for "0001-01-01 00:00:00.000000". A valid timestamp value needs to be greater than or
// equal to this number.
while (milliseconds < -62135740800000L) {
milliseconds = rand.nextLong() % 253402329599999L
}
// DateTimeUtils.toJavaTimestamp takes microseconds.
DateTimeUtils.toJavaTimestamp(milliseconds * 1000)
}
Some(generator)
case CalendarIntervalType => Some(() => {
val months = rand.nextInt(1000)
val ns = rand.nextLong()
@@ -159,6 +177,27 @@ object RandomDataGenerator {
None
}
}
case udt: UserDefinedType[_] => {
val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed)
// Because the random data generator here returns a Scala value, we need to convert it
// to a Catalyst value before calling the UDT's deserialize.
val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType)

if (maybeSqlTypeGenerator.isDefined) {
val sqlTypeGenerator = maybeSqlTypeGenerator.get
val generator = () => {
val generatedScalaValue = sqlTypeGenerator.apply()
if (generatedScalaValue == null) {
null
} else {
udt.deserialize(toCatalystType(generatedScalaValue))
}
}
Some(generator)
} else {
None
}
}
case unsupportedType => None
}
// Handle nullability by wrapping the non-null value generator:
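To make the new TimestampType branch concrete, here is a small sanity sketch (an illustration under the assumptions stated in the comments above, not code from the patch): every generated value should land inside the "0001-01-01" to "9999-12-31" window, because the modulo keeps the milliseconds below the upper bound and the while loop rejects anything below the lower bound.

    // Sketch only; the bounds are the ones quoted in the comments above.
    val tsGen = RandomDataGenerator.forType(TimestampType, nullable = false, seed = Some(0L)).get
    val ts = tsGen().asInstanceOf[java.sql.Timestamp]
    assert(ts.getTime >= -62135740800000L && ts.getTime <= 253402329599999L)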
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils

import scala.collection.Map

@@ -89,7 +90,7 @@ private[sql] object JacksonGenerator {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v) => gen.writeString(v.toString)
case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString)
case (IntegerType, v: Int) => gen.writeNumber(v)
case (ShortType, v: Short) => gen.writeNumber(v)
case (FloatType, v: Float) => gen.writeNumber(v)
@@ -99,8 +100,12 @@
case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
case (BooleanType, v: Boolean) => gen.writeBoolean(v)
case (DateType, v) => gen.writeString(v.toString)
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString)
// UDT values should arrive here in the value type of the UDT's SQL type, not as instances
// of the user-defined class. For example, VectorUDT's SQL type is an array of doubles, so
// v should be an ArrayData here instead of a Vector.
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

case (ArrayType(ty, _), v: ArrayData) =>
gen.writeStartArray()
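The two rewritten cases above reflect that JacksonGenerator now receives Catalyst internal values: TimestampType arrives as a Long of microseconds since the epoch and DateType as an Int of days since the epoch, so both are converted back through DateTimeUtils before being written as strings. A minimal sketch of that conversion (illustrative only; the printed strings depend on the JVM time zone):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val micros: Long = 1000000L                    // TimestampType internal value: microseconds
    val days: Int = 1                              // DateType internal value: days since the epoch
    println(DateTimeUtils.toJavaTimestamp(micros)) // 1970-01-01 00:00:01.0 (in UTC)
    println(DateTimeUtils.toJavaDate(days))        // 1970-01-02 (in UTC)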
@@ -81,9 +81,37 @@ private[sql] object JacksonParser {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
parser.getFloatValue

case (VALUE_STRING, FloatType) =>
// Special case handling for NaN and Infinity.
val value = parser.getText
val lowerCaseValue = value.toLowerCase()
if (lowerCaseValue.equals("nan") ||
lowerCaseValue.equals("infinity") ||
lowerCaseValue.equals("-infinity") ||
lowerCaseValue.equals("inf") ||
lowerCaseValue.equals("-inf")) {
value.toFloat
} else {
sys.error(s"Cannot parse $value as FloatType.")
}

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
parser.getDoubleValue

case (VALUE_STRING, DoubleType) =>
// Special case handling for NaN and Infinity.
val value = parser.getText
val lowerCaseValue = value.toLowerCase()
if (lowerCaseValue.equals("nan") ||
lowerCaseValue.equals("infinity") ||
lowerCaseValue.equals("-infinity") ||
lowerCaseValue.equals("inf") ||
lowerCaseValue.equals("-inf")) {
value.toDouble
} else {
sys.error(s"Cannot parse $value as DoubleType.")
}

case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
Decimal(parser.getDecimalValue, dt.precision, dt.scale)

@@ -126,6 +154,9 @@ private[sql] object JacksonParser {

case (_, udt: UserDefinedType[_]) =>
convertField(factory, parser, udt.sqlType)

case (token, dataType) =>
sys.error(s"Failed to parse a value for data type $dataType (current token: $token).")
}
}

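The new VALUE_STRING branches above rely on the standard Java/Scala numeric parsers once a string passes the whitelist check. A quick sketch of that underlying behavior (not from the patch):

    assert("NaN".toDouble.isNaN)
    assert("Infinity".toDouble.isPosInfinity)
    assert("-Infinity".toFloat.isNegInfinity)
    // Note: java.lang.Double.parseDouble does not accept "inf"/"-inf", so those two whitelisted
    // spellings would surface a NumberFormatException from value.toDouble rather than a value.

The explicit catch-all case at the end turns any unhandled token/type combination into an error instead of a MatchError.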
@@ -29,6 +29,14 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
import sqlContext._
import sqlContext.implicits._

// ORC does not play well with NullType and UDT.
override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: NullType => false
case _: CalendarIntervalType => false
case _: UserDefinedType[_] => false
case _ => true
}

test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -30,6 +30,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {

import sqlContext._

// JSON does not write data of NullType and does not play well with BinaryType.
override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: NullType => false
case _: BinaryType => false
case _: CalendarIntervalType => false
case _ => true
}

test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.{execution, AnalysisException, SaveMode}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types._


class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -33,6 +33,13 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
import sqlContext._
import sqlContext.implicits._

// Parquet does not play well with NullType.
override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: NullType => false
case _: CalendarIntervalType => false
case _ => true
}

test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -20,13 +20,30 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types._

class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName

import sqlContext._

// We support only a very limited set of types here, since this is just a test relation
// and we only do very basic testing with it.
override protected def supportsDataType(dataType: DataType): Boolean = dataType match {
case _: BinaryType => false
// We are using the random data generator and the generated strings are not really valid strings for this format.
case _: StringType => false
case _: BooleanType => false // see https://issues.apache.org/jira/browse/SPARK-10442
case _: CalendarIntervalType => false
case _: DateType => false
case _: TimestampType => false
case _: ArrayType => false
case _: MapType => false
case _: StructType => false
case _: UserDefinedType[_] => false
case _ => true
}

test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
val basePath = new Path(file.getCanonicalPath)
@@ -68,7 +68,9 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends
new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)

override def write(row: Row): Unit = {
val serialized = row.toSeq.map(_.toString).mkString(",")
val serialized = row.toSeq.map { v =>
if (v == null) "" else v.toString
}.mkString(",")
recordWriter.write(null, new Text(serialized))
}

@@ -112,7 +114,8 @@ class SimpleTextRelation(
val fields = dataSchema.map(_.dataType)

sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record =>
Row(record.split(",").zip(fields).map { case (value, dataType) =>
Row(record.split(",", -1).zip(fields).map { case (v, dataType) =>
val value = if (v == "") null else v
// `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.)
val catalystValue = Cast(Literal(value), dataType).eval()
// Here we're converting Catalyst values to Scala values to test `needsConversion`
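The -1 limit passed to split above matters for the new null handling: nulls are written out as empty fields, and without the limit java.lang.String.split silently drops trailing empty strings, so a row ending in nulls would come back with too few columns. A tiny illustration of the standard-library behavior (not from the patch):

    assert("a,b,".split(",").length == 2)      // trailing empty field dropped
    assert("a,b,".split(",", -1).length == 3)  // trailing empty field preserved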
@@ -40,6 +40,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {

val dataSourceName: String

protected def supportsDataType(dataType: DataType): Boolean = true

val dataSchema =
StructType(
Seq(
@@ -100,6 +102,83 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
}
}

test("test all data types") {
withTempPath { file =>
// Create the schema.
val struct =
StructType(
StructField("f1", FloatType, true) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
// TODO: add CalendarIntervalType here once we can save it out.
val dataTypes =
Seq(
StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType), MapType(StringType, LongType), struct,
new MyDenseVectorUDT())
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, nullable = true)
}
val schema = StructType(fields)

// Generate data on the driver side. We need to materialize the data first and then
// create the RDD.
val maybeDataGenerator =
RandomDataGenerator.forType(
dataType = schema,
nullable = true,
seed = Some(System.nanoTime()))
val dataGenerator =
maybeDataGenerator
.getOrElse(fail(s"Failed to create data generator for schema $schema"))
val data = (1 to 10).map { i =>
dataGenerator.apply() match {
case row: Row => row
case null => Row.fromSeq(Seq.fill(schema.length)(null))
case other =>
fail(s"Row or null is expected to be generated, " +
s"but a ${other.getClass.getCanonicalName} is generated.")
}
}

// Create a DF for the schema with random data.
val rdd = sqlContext.sparkContext.parallelize(data, 10)
val df = sqlContext.createDataFrame(rdd, schema)

// All columns of this schema whose data types are supported by this data source.
val supportedColumns = schema.fields.collect {
case StructField(name, dataType, _, _) if supportsDataType(dataType) => name
}
val selectedColumns = util.Random.shuffle(supportedColumns.toSeq)

val dfToBeSaved = df.selectExpr(selectedColumns: _*)

// Save the data out.
dfToBeSaved
.write
.format(dataSourceName)
.option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests.
.save(file.getCanonicalPath)

val loadedDF =
sqlContext
.read
.format(dataSourceName)
.schema(dfToBeSaved.schema)
.option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests.
.load(file.getCanonicalPath)
.selectExpr(selectedColumns: _*)

// Read the data back and check that it matches what we saved.
checkAnswer(
loadedDF,
dfToBeSaved
)
}
}

test("save()/load() - non-partitioned table - Overwrite") {
withTempPath { file =>
testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath)