Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,6 @@ def hash(*cols):
'uppercase. Words are delimited by whitespace.',
'lower': 'Converts a string column to lower case.',
'upper': 'Converts a string column to upper case.',
'reverse': 'Reverses the string column and returns it as a new string column.',
'ltrim': 'Trim the spaces from left end for the specified string value.',
'rtrim': 'Trim the spaces from right end for the specified string value.',
'trim': 'Trim the spaces from both ends for the specified string column.',
Expand Down Expand Up @@ -2113,6 +2112,25 @@ def sort_array(col, asc=True):
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


@since(1.5)
@ignore_unicode_prefix
def reverse(col):
"""
Collection function: returns a reversed string or an array with reverse order of elements.

:param col: name of column or expression

>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
>>> df.select(reverse(df.data).alias('s')).collect()
[Row(s=u'LQS krapS')]
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
>>> df.select(reverse(df.data).alias('r')).collect()
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.3)
def map_keys(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[StringReverse]("reverse"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
Expand Down Expand Up @@ -410,6 +409,7 @@ object FunctionRegistry {
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[ArrayMax]("array_max"),
expression[Reverse]("reverse"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Given an array or map, returns its size. Returns -1 if null.
Expand Down Expand Up @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}

/**
* Returns a reversed string or an array with reverse order of elements.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
> SELECT _FUNC_(array(2, 1, 4, 3));
[3, 4, 1, 2]
""",
since = "1.5.0",
note = "Reverse logic for arrays is available since 2.4.0."
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))

override def dataType: DataType = child.dataType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
case s: UTF8String => s.reverse()
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => dataType match {
case _: StringType => stringCodeGen(ev, c)
case _: ArrayType => arrayCodeGen(ctx, ev, c)
})
}

private def stringCodeGen(ev: ExprCode, childName: String): String = {
s"${ev.value} = ($childName).reverse();"
}

private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
val length = ctx.freshName("length")
val javaElementType = CodeGenerator.javaType(elementType)
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)

val initialization = if (isPrimitiveType) {
s"$childName.copy()"
} else {
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
}

val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length

val swapAssigments = if (isPrimitiveType) {
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|if(!isNullAtK) {
| $javaElementType el = ${getCall("k")};
| if(!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| } else {
| ${ev.value}.setNullAt(k);
| }
| ${ev.value}.$setFunc(l, el);
|} else if (!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| ${ev.value}.setNullAt(l);
|}""".stripMargin
} else {
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
}

s"""
|final int $length = $childName.numElements();
|${ev.value} = $initialization;
|for(int k = 0; k < $numberOfIterations; k++) {
| int l = $length - k - 1;
| $swapAssigments
|}
""".stripMargin
}

override def prettyName: String = "reverse"
}

/**
* Checks if the array (left) has the element (right)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression)
}
}

/**
* Returns the reversed given string.
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the reversed given string.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
""")
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.reverse()

override def prettyName: String = "reverse"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).reverse()")
}
}

/**
* Returns a string consisting of n spaces.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}

test("Reverse") {
// Primitive-type elements
val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
val ai7 = Literal.create(null, ArrayType(IntegerType))

checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
checkEvaluation(Reverse(ai4), Seq(null, null, null))
checkEvaluation(Reverse(ai5), Seq(1))
checkEvaluation(Reverse(ai6), Seq.empty)
checkEvaluation(Reverse(ai7), null)

// Non-primitive-type elements
val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
val as5 = Literal.create(Seq("a"), ArrayType(StringType))
val as6 = Literal.create(Seq.empty, ArrayType(StringType))
val as7 = Literal.create(null, ArrayType(StringType))
val aa = Literal.create(
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
ArrayType(ArrayType(StringType)))

checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
checkEvaluation(Reverse(as4), Seq(null, null, null))
checkEvaluation(Reverse(as5), Seq("a"))
checkEvaluation(Reverse(as6), Seq.empty)
checkEvaluation(Reverse(as7), null)
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("REVERSE") {
val s = 'a.string.at(0)
val row1 = create_row("abccc")
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
checkEvaluation(StringReverse(s), "cccba", row1)
checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
checkEvaluation(Reverse(s), "cccba", row1)
checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
}

test("SPACE") {
Expand Down
15 changes: 7 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2464,14 +2464,6 @@ object functions {
StringRepeat(str.expr, lit(n).expr)
}

/**
* Reverses the string column and returns it as a new string column.
*
* @group string_funcs
* @since 1.5.0
*/
def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }

/**
* Trim the spaces from right end for the specified string value.
*
Expand Down Expand Up @@ -3308,6 +3300,13 @@ object functions {
*/
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }

/**
* Returns a reversed string or an array with reverse order of elements.
* @group collection_funcs
* @since 1.5.0
*/
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.selectExpr("array_max(a)"), answer)
}

test("reverse function") {
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on

// String test cases
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")

checkAnswer(
oneRowDF.select(reverse('s)),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(s)"),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.select(reverse('i)),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(i)"),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(null)"),
Seq(Row(null))
)

// Array test cases (primitive-type elements)
val idf = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

checkAnswer(
idf.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.filter(dummyFilter('i)).select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)
checkAnswer(
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)

// Array test cases (non-primitive-type elements)
val sdf = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

checkAnswer(
sdf.select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.filter(dummyFilter('s)).select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.selectExpr("reverse(s)"),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)
checkAnswer(
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)

// Error test cases
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
}
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(map(1, 'a'))")
}
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down