diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index cc05828cfcccb..6a945173803b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -83,9 +83,14 @@ trait ExpressionWithRandomSeed {
   """,
   since = "1.5.0")
 // scalastyle:on line.size.limit
-case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {
+case class Rand(child: Expression, hideSeed: Boolean = false)
+  extends RDG with ExpressionWithRandomSeed {
 
-  def this() = this(Literal(Utils.random.nextLong(), LongType))
+  // No seed argument was supplied, so hide the generated seed from the SQL text.
+  def this() = this(Literal(Utils.random.nextLong(), LongType), true)
+
+  def this(child: Expression) = this(child, false)
 
-  override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType))
+  // Carry `hideSeed` over so that re-seeding does not change the SQL text.
+  override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed)
 
@@ -101,7 +106,14 @@ case class Rand(child: Expression) extends RDG with ExpressionWithRandomSeed {
       isNull = FalseLiteral)
   }
 
-  override def freshCopy(): Rand = Rand(child)
+  override def freshCopy(): Rand = Rand(child, hideSeed)
+
+  // List only `child` as an argument; `hideSeed` is a display flag, not an input.
+  override def flatArguments: Iterator[Any] = Iterator(child)
+  // Render `rand()` when the seed is hidden, `rand(<seed>)` otherwise.
+  override def sql: String = {
+    s"rand(${if (hideSeed) "" else child.sql})"
+  }
 }
 
 object Rand {
@@ -126,9 +138,14 @@ object Rand {
   """,
   since = "1.5.0")
 // scalastyle:on line.size.limit
-case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {
+case class Randn(child: Expression, hideSeed: Boolean = false)
+  extends RDG with ExpressionWithRandomSeed {
 
-  def this() = this(Literal(Utils.random.nextLong(), LongType))
+  // No seed argument was supplied, so hide the generated seed from the SQL text.
+  def this() = this(Literal(Utils.random.nextLong(), LongType), true)
+
+  def this(child: Expression) = this(child, false)
 
-  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType))
+  // Carry `hideSeed` over so that re-seeding does not change the SQL text.
+  override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed)
 
@@ -144,7 +161,14 @@ case class Randn(child: Expression) extends RDG with ExpressionWithRandomSeed {
       isNull = FalseLiteral)
   }
 
-  override def freshCopy(): Randn = Randn(child)
+  override def freshCopy(): Randn = Randn(child, hideSeed)
+
+  // List only `child` as an argument; `hideSeed` is a display flag, not an input.
+  override def flatArguments: Iterator[Any] = Iterator(child)
+  // Render `randn()` when the seed is hidden, `randn(<seed>)` otherwise.
+  override def sql: String = {
+    s"randn(${if (hideSeed) "" else child.sql})"
+  }
 }
 
 object Randn {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
index 469c24b3b5f49..2aa53f581555f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala
@@ -34,4 +34,14 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Rand(5419823303878592871L), 0.7145363364564755)
     checkEvaluation(Randn(5419823303878592871L), 0.7816815274533012)
   }
+
+  test("SPARK-31594: Do not display the seed of rand/randn with no argument in output schema") {
+    assert(Rand(Literal(1L), true).sql === "rand()")
+    assert(Randn(Literal(1L), true).sql === "randn()")
+    assert(Rand(Literal(1L), false).sql === "rand(1L)")
+    assert(Randn(Literal(1L), false).sql === "randn(1L)")
+    // Re-seeding must not make a hidden seed visible again.
+    assert(Rand(Literal(1L), true).withNewSeed(2L).sql === "rand()")
+    assert(Randn(Literal(1L), true).withNewSeed(2L).sql === "randn()")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a958ab8064ba9..d5247fba00283 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3425,6 +3425,30 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true)
     }
   }
+
+  test("SPARK-31594: Do not display the seed of rand/randn with no argument in output schema") {
+    // The seed must still appear in the explain output, even when the column name omits it.
+    def checkIfSeedExistsInExplain(df: DataFrame): Unit = {
+      val output = new java.io.ByteArrayOutputStream()
+      Console.withOut(output) {
+        df.explain()
+      }
+      val projectExplainOutput = output.toString.split("\n").find(_.contains("Project")).get
+      assert(projectExplainOutput.matches(""".*randn?\(-?[0-9]+\).*"""))
+    }
+    val df1 = sql("SELECT rand()")
+    assert(df1.schema.head.name === "rand()")
+    checkIfSeedExistsInExplain(df1)
+    val df2 = sql("SELECT rand(1L)")
+    assert(df2.schema.head.name === "rand(1)")
+    checkIfSeedExistsInExplain(df2)
+    val df3 = sql("SELECT randn()")
+    assert(df3.schema.head.name === "randn()")
+    checkIfSeedExistsInExplain(df3)
+    val df4 = sql("SELECT randn(1L)")
+    assert(df4.schema.head.name === "randn(1)")
+    checkIfSeedExistsInExplain(df4)
+  }
 }
 
 case class Foo(bar: Option[String])