+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.Double.longBitsToDouble
+import java.lang.Float.intBitsToFloat
+import java.math.MathContext
+
+import scala.util.Random
+
+import org.apache.spark.sql.types._
+
+/**
+ * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly
+ * random values; instead, they're biased to return "interesting" values (such as maximum /
+ * minimum values) with higher probability.
+ */
+object RandomDataGenerator {
+
+  /**
+   * The conditional probability of a non-null value being drawn from a set of "interesting"
+   * values instead of being chosen uniformly at random.
+   */
+  private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f
+
+  /**
+   * The probability of the generated value being null.
+   */
+  private val PROBABILITY_OF_NULL: Float = 0.1f
+
+  private val MAX_STR_LEN: Int = 1024
+  private val MAX_ARR_SIZE: Int = 128
+  private val MAX_MAP_SIZE: Int = 128
+
+  /**
+   * Helper function for constructing a biased random number generator which returns
+   * "interesting" values with a higher probability.
+   */
+  private def randomNumeric[T](
+      rand: Random,
+      uniformRand: Random => T,
+      interestingValues: Seq[T]): Some[() => T] = {
+    val f = () => {
+      if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) {
+        interestingValues(rand.nextInt(interestingValues.length))
+      } else {
+        uniformRand(rand)
+      }
+    }
+    Some(f)
+  }
+
+  /**
+   * Returns a function which generates random values for the given [[DataType]], or `None` if no
+   * random data generator is defined for that data type. The generated values will use an
+   * external representation of the data type; for example, the random generator for [[DateType]]
+   * will return instances of [[java.sql.Date]] and the generator for [[StructType]] will return a
+   * [[org.apache.spark.sql.Row]].
+   *
+   * @param dataType the type to generate values for
+   * @param nullable whether null values should be generated
+   * @param seed an optional seed for the random number generator
+   * @return a function which can be called to generate random values, or `None` if no generator
+   *         is defined for the given data type.
+   */
+  def forType(
+      dataType: DataType,
+      nullable: Boolean = true,
+      seed: Option[Long] = None): Option[() => Any] = {
+    val rand = new Random()
+    seed.foreach(rand.setSeed)
+
+    val valueGenerator: Option[() => Any] = dataType match {
+      case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN)))
+      case BinaryType => Some(() => {
+        val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN))
+        rand.nextBytes(arr)
+        arr
+      })
+      case BooleanType => Some(() => rand.nextBoolean())
+      case DateType => Some(() => new java.sql.Date(rand.nextInt()))
+      case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
+      case DecimalType.Unlimited => Some(
+        () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED))
+      case DoubleType => randomNumeric[Double](
+        rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
+          Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
+      case FloatType => randomNumeric[Float](
+        rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue,
+          Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f))
+      case ByteType => randomNumeric[Byte](
+        rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte))
+      case IntegerType => randomNumeric[Int](
+        rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0))
+      case LongType => randomNumeric[Long](
+        rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L))
+      case ShortType => randomNumeric[Short](
+        rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
+      case NullType => Some(() => null)
+      case ArrayType(elementType, containsNull) => {
+        forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
+          elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
+        }
+      }
+      case MapType(keyType, valueType, valueContainsNull) => {
+        for (
+          keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong()));
+          valueGenerator <-
+            forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong()))
+        ) yield {
+          () => {
+            Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
+          }
+        }
+      }
+      case StructType(fields) => {
+        val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
+          forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong()))
+        }
+        if (maybeFieldGenerators.forall(_.isDefined)) {
+          val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get)
+          Some(() => Row.fromSeq(fieldGenerators.map(_.apply())))
+        } else {
+          None
+        }
+      }
+      case _ => None
+    }
+    // Handle nullability by wrapping the non-null value generator:
+    valueGenerator.map { valueGenerator =>
+      if (nullable) {
+        () => {
+          if (rand.nextFloat() <= PROBABILITY_OF_NULL) {
+            null
+          } else {
+            valueGenerator()
+          }
+        }
+      } else {
+        valueGenerator
+      }
+    }
+  }
+}
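
Roughly how a test might call this generator, as a minimal sketch that is not part of the patch: the schema, seed, and variable names below are illustrative only, and the snippet assumes RandomDataGenerator is visible on the (test) classpath, for example from a Spark REPL or a test suite in the same module.

    import org.apache.spark.sql.{RandomDataGenerator, Row}
    import org.apache.spark.sql.types._

    // Hypothetical schema, used only for illustration.
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("score", DoubleType, nullable = false)))

    // forType returns None if any nested type has no generator defined.
    val maybeGen: Option[() => Any] =
      RandomDataGenerator.forType(schema, nullable = false, seed = Some(42L))

    maybeGen.foreach { generate =>
      // For a StructType, each call yields a Row whose fields are drawn with a bias toward
      // "interesting" values and, for nullable fields, toward nulls.
      (1 to 5).foreach(_ => println(generate().asInstanceOf[Row]))
    }

Because the generators are deliberately biased toward boundary values and nulls (per PROBABILITY_OF_INTERESTING_VALUE and PROBABILITY_OF_NULL above), even a handful of generated rows tends to exercise edge cases such as Double.NaN, Long.MinValue, empty strings, and null fields that uniform sampling would rarely produce.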