@@ -68,6 +68,10 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     return ((numFields + 63)/ 64) * 8;
   }
 
+  public static int calculateFixedPortionByteSize(int numFields) {
+    return 8 * numFields + calculateBitSetWidthInBytes(numFields);
+  }
+
   /**
    * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
    */
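The new helper sizes the fixed-width region of an UnsafeRow: one 64-bit word per 64 fields for the null-tracking bit set, plus an 8-byte slot per field. A standalone Scala sketch of the arithmetic (a sanity check, not part of the patch):

```scala
// Mirror of the Java helpers above, for checking the arithmetic.
def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8
def fixedPortionByteSize(numFields: Int): Int =
  8 * numFields + bitSetWidthInBytes(numFields)

// 5 fields: one 64-bit null word (8 bytes) + 5 * 8-byte slots = 48 bytes.
assert(fixedPortionByteSize(5) == 48)
// 65 fields: the bit set grows to two words (16 bytes) + 65 * 8 = 536 bytes.
assert(fixedPortionByteSize(65) == 536)
```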
@@ -596,10 +600,9 @@ public byte[] getBytes() {
   public String toString() {
     StringBuilder build = new StringBuilder("[");
     for (int i = 0; i < sizeInBytes; i += 8) {
+      if (i != 0) build.append(',');
       build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
-      build.append(',');
     }
-    build.deleteCharAt(build.length() - 1);
     build.append(']');
     return build.toString();
   }
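The toString change swaps a trailing-separator pattern for a leading one: the old code always appended a comma and deleted the last character afterwards, which removes the opening bracket when the loop body never runs (a zero-byte row). A standalone sketch of the two patterns (inputs are illustrative, not part of the patch):

```scala
// Old pattern: append a trailing ',' then delete the last char.
// With no elements, deleteCharAt removes the '[' instead.
def oldStyle(words: Seq[Long]): String = {
  val sb = new StringBuilder("[")
  words.foreach(w => sb.append(java.lang.Long.toHexString(w)).append(','))
  sb.deleteCharAt(sb.length - 1) // removes '[' when words is empty
  sb.append(']').toString
}

// New pattern: prepend the ',' for every element except the first.
def newStyle(words: Seq[Long]): String = {
  val sb = new StringBuilder("[")
  for ((w, i) <- words.zipWithIndex) {
    if (i != 0) sb.append(',')
    sb.append(java.lang.Long.toHexString(w))
  }
  sb.append(']').toString
}

assert(oldStyle(Nil) == "]")   // corrupted output
assert(newStyle(Nil) == "[]")  // well-formed
```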
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A parent class for mutable container objects that are reused when the values are changed,
@@ -212,6 +211,8 @@ final class SpecificMutableRow(val values: Array[MutableValue])

   def this() = this(Seq.empty)
 
+  def this(schema: StructType) = this(schema.fields.map(_.dataType))
+
   override def numFields: Int = values.length
 
   override def setNullAt(i: Int): Unit = {
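The new auxiliary constructor lets callers build a SpecificMutableRow directly from a StructType instead of mapping out the field types themselves. A hypothetical usage sketch (the schema, field names, and setter calls are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types._

// One typed mutable slot per schema field.
val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)))

// Equivalent to: new SpecificMutableRow(schema.fields.map(_.dataType))
val row = new SpecificMutableRow(schema)
row.setLong(0, 42L)
row.setNullAt(1)
```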
@@ -21,6 +21,7 @@ import java.lang.Double.longBitsToDouble
 import java.lang.Float.intBitsToFloat
 import java.math.MathContext
 
+import scala.collection.mutable
 import scala.util.Random
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -74,13 +75,47 @@ object RandomDataGenerator {
    * @param numFields the number of fields in this schema
    * @param acceptedTypes types to draw from.
    */
-  def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = {
+  def randomSchema(rand: Random, numFields: Int, acceptedTypes: Seq[DataType]): StructType = {
     StructType(Seq.tabulate(numFields) { i =>
-      val dt = acceptedTypes(Random.nextInt(acceptedTypes.size))
-      StructField("col_" + i, dt, nullable = true)
+      val dt = acceptedTypes(rand.nextInt(acceptedTypes.size))
+      StructField("col_" + i, dt, nullable = rand.nextBoolean())
     })
   }
 
+  /**
+   * Returns a random nested schema. This will randomly generate structs and arrays drawn from
+   * acceptedTypes.
+   */
+  def randomNestedSchema(rand: Random, totalFields: Int, acceptedTypes: Seq[DataType]):
+      StructType = {
+    val fields = mutable.ArrayBuffer.empty[StructField]
+    var i = 0
+    var numFields = totalFields
+    while (numFields > 0) {
+      val v = rand.nextInt(3)
+      if (v == 0) {
+        // Simple type:
+        val dt = acceptedTypes(rand.nextInt(acceptedTypes.size))
+        fields += new StructField("col_" + i, dt, rand.nextBoolean())
+        numFields -= 1
+      } else if (v == 1) {
+        // Array
+        val dt = acceptedTypes(rand.nextInt(acceptedTypes.size))
+        fields += new StructField("col_" + i, ArrayType(dt), rand.nextBoolean())
+        numFields -= 1
+      } else {
+        // Struct
+        // TODO: do empty structs make sense?
+        val n = Math.max(rand.nextInt(numFields), 1)
+        val nested = randomNestedSchema(rand, n, acceptedTypes)
+        fields += new StructField("col_" + i, nested, rand.nextBoolean())
+        numFields -= n
+      }
+      i += 1
+    }
+    StructType(fields)
+  }
+
   /**
    * Returns a function which generates random values for the given [[DataType]], or `None` if no
    * random data generator is defined for that data type. The generated values will use an external
@@ -90,16 +125,13 @@ object RandomDataGenerator {
    *
    * @param dataType the type to generate values for
    * @param nullable whether null values should be generated
-   * @param seed an optional seed for the random number generator
+   * @param rand an optional random number generator
    * @return a function which can be called to generate random values.
    */
   def forType(
       dataType: DataType,
       nullable: Boolean = true,
-      seed: Option[Long] = None): Option[() => Any] = {
-    val rand = new Random()
-    seed.foreach(rand.setSeed)
-
+      rand: Random = new Random): Option[() => Any] = {
     val valueGenerator: Option[() => Any] = dataType match {
       case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN)))
       case BinaryType => Some(() => {
@@ -165,15 +197,15 @@
         rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
       case NullType => Some(() => null)
       case ArrayType(elementType, containsNull) => {
-        forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
+        forType(elementType, nullable = containsNull, rand).map {
           elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
         }
       }
       case MapType(keyType, valueType, valueContainsNull) => {
         for (
-          keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong()));
+          keyGenerator <- forType(keyType, nullable = false, rand);
           valueGenerator <-
-            forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong()))
+            forType(valueType, nullable = valueContainsNull, rand)
         ) yield {
           () => {
             Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
@@ -182,7 +214,7 @@
         }
       }
       case StructType(fields) => {
         val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
-          forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong()))
+          forType(field.dataType, nullable = field.nullable, rand)
         }
         if (maybeFieldGenerators.forall(_.isDefined)) {
           val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get)
@@ -192,7 +224,7 @@
         }
       }
       case udt: UserDefinedType[_] => {
-        val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, seed)
+        val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand)
         // Because random data generator at here returns scala value, we need to
         // convert it to catalyst value to call udt's deserialize.
         val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType)
@@ -229,4 +261,40 @@
       }
     }
   }
+
+  // Generates a random row for `schema`.
+  def randomRow(rand: Random, schema: StructType): Row = {
+    val fields = mutable.ArrayBuffer.empty[Any]
+    schema.fields.foreach { f =>
+      f.dataType match {
+        case ArrayType(childType, nullable) => {
+          val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) {
+            null
+          } else {
+            val arr = mutable.ArrayBuffer.empty[Any]
+            val n = 1 // rand.nextInt(10)
+            var i = 0
+            val generator = RandomDataGenerator.forType(childType, nullable, rand)
+            assert(generator.isDefined, "Unsupported type")
+            val gen = generator.get
+            while (i < n) {
+              arr += gen()
+              i += 1
+            }
+            arr
+          }
+          fields += data
+        }
+        case StructType(children) => {
+          fields += randomRow(rand, StructType(children))
+        }
+        case _ =>
+          val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand)
+          assert(generator.isDefined, "Unsupported type")
+          val gen = generator.get
+          fields += gen()
+      }
+    }
+    Row.fromSeq(fields)
+  }
 }
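Threading a single Random through forType, randomSchema, randomNestedSchema, and randomRow replaces the old per-call Option[Long] seeds, so one seed now determines an entire generation run. randomNestedSchema spends a budget of totalFields, recursing into struct fields until the budget is exhausted, and randomRow produces rows in the external format. A sketch of how the pieces compose (assuming the test-side RandomDataGenerator shown above is on the classpath; the type lists and seed are illustrative):

```scala
import scala.util.Random

import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.types._

// One seeded Random drives every draw, so the whole run is reproducible.
val rand = new Random(42)

// Flat schema: 4 fields with randomized types and nullability.
val flat = RandomDataGenerator.randomSchema(
  rand, 4, Seq(IntegerType, StringType, DoubleType))

// Nested schema: a budget of 10 fields, some consumed by nested structs.
val nested = RandomDataGenerator.randomNestedSchema(
  rand, 10, Seq(IntegerType, StringType))

// A value generator for a single type, and a full external-format row.
val intGen = RandomDataGenerator.forType(IntegerType, nullable = false, rand).get
val row = RandomDataGenerator.randomRow(rand, nested)
println(s"${intGen()} / $row")
```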
@@ -17,6 +17,8 @@

 package org.apache.spark.sql
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.types._
@@ -32,7 +34,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
    */
   def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = {
     val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
-    val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse {
+    val generator = RandomDataGenerator.forType(dataType, nullable, new Random(33)).getOrElse {
       fail(s"Random data generator was not defined for $dataType")
     }
     if (nullable) {
@@ -74,8 +74,9 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {

   private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) {
     info(s"schema size $numFields1, $numFields2")
-    val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
-    val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)
+    val random = new Random()
+    val schema1 = RandomDataGenerator.randomSchema(random, numFields1, candidateTypes)
+    val schema2 = RandomDataGenerator.randomSchema(random, numFields2, candidateTypes)
 
     // Create the converters needed to convert from external row to internal row and to UnsafeRows.
     val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
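Since the suites now own the Random instance, a failing randomized schema can be replayed by logging and pinning the seed. A hypothetical standalone variant of the pattern used in the call site above (the seed handling is illustrative, not part of the patch):

```scala
import scala.util.Random

import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.types._

// Hypothetical: log the seed so a failing schema pair can be regenerated.
val seed = System.currentTimeMillis()
println(s"seed = $seed")
val random = new Random(seed)

val candidateTypes = Seq(IntegerType, DoubleType, StringType)
val schema1 = RandomDataGenerator.randomSchema(random, 10, candidateTypes)
val schema2 = RandomDataGenerator.randomSchema(random, 7, candidateTypes)
// Re-running with new Random(seed) yields the exact same pair of schemas.
```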