Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ object FunctionRegistry {
expression[BitwiseCount]("bit_count"),
expression[BitAndAgg]("bit_and"),
expression[BitOrAgg]("bit_or"),
expression[BitXorAgg]("bit_xor"),

// json
expression[StructsToJson]("to_json"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,14 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BitwiseAnd, BitwiseOr, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryArithmetic, BitwiseAnd, BitwiseOr, BitwiseXor, ExpectsInputTypes, Expression, ExpressionDescription, If, IsNull, Literal}
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.",
examples = """
Examples:
> SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
1
""",
since = "3.0.0")
case class BitAndAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
/**
 * Shared base for bitwise aggregate functions over an integral child expression.
 *
 * Subclasses supply `child` and `bitOperator`; the aggregation buffer and the update / merge
 * expressions are defined here in terms of `bitOperator`.
 *
 * NOTE(review): this text is a rendered diff, so lines from the pre-refactor BitAndAgg
 * (the ones referring to `bitAnd` / `BitwiseAnd`) appear interleaved with the new `bitAgg`
 * lines, and part of the body is hidden behind the collapsed hunk marker — confirm against
 * the real file.
 */
abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes {

override def nodeName: String = "bit_and"
// Input expression whose values are folded into the aggregate.
val child: Expression

// Builds the binary bitwise expression that combines the buffer with an input value,
// or two partial buffers with each other.
def bitOperator(left: Expression, right: Expression): BinaryArithmetic

override def children: Seq[Expression] = child :: Nil

Expand All @@ -40,23 +34,40 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect

override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

private lazy val bitAnd = AttributeReference("bit_and", child.dataType)()

override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAnd :: Nil
// Single buffer attribute, named after nodeName and typed like the child.
private lazy val bitAgg = AttributeReference(nodeName, child.dataType)()

// The buffer starts out as a null literal of `dataType`.
override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil

override lazy val aggBufferAttributes: Seq[AttributeReference] = bitAgg :: Nil

// The final result is the buffer attribute itself.
override lazy val evaluateExpression: AttributeReference = bitAgg

// Null buffer: take the input as-is. Null input: keep the buffer. Otherwise combine both
// with bitOperator.
override lazy val updateExpressions: Seq[Expression] =
If(IsNull(bitAnd),
If(IsNull(bitAgg),
child,
If(IsNull(child), bitAnd, BitwiseAnd(bitAnd, child))) :: Nil
If(IsNull(child), bitAgg, bitOperator(bitAgg, child))) :: Nil

// Same null handling as the update step, applied to the left and right partial buffers.
override lazy val mergeExpressions: Seq[Expression] =
If(IsNull(bitAnd.left),
bitAnd.right,
If(IsNull(bitAnd.right), bitAnd.left, BitwiseAnd(bitAnd.left, bitAnd.right))) :: Nil
If(IsNull(bitAgg.left),
bitAgg.right,
If(IsNull(bitAgg.right), bitAgg.left, bitOperator(bitAgg.left, bitAgg.right))) :: Nil
}

/**
 * Bitwise-AND aggregate: a [[BitAggregate]] whose node name is "bit_and" and whose
 * combining operator is [[BitwiseAnd]].
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the bitwise AND of all non-null input values, or null if none.",
examples = """
Examples:
> SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
1
""",
since = "3.0.0")
case class BitAndAgg(child: Expression) extends BitAggregate {

// NOTE(review): `bitAnd` is not declared in this class as shown here; this line looks like
// a leftover from the pre-refactor version that the diff removes — confirm.
override lazy val evaluateExpression: AttributeReference = bitAnd
override def nodeName: String = "bit_and"

// Combines two operands with a BitwiseAnd expression.
override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
BitwiseAnd(left, right)
}
}

@ExpressionDescription(
Expand All @@ -67,33 +78,28 @@ case class BitAndAgg(child: Expression) extends DeclarativeAggregate with Expect
7
""",
since = "3.0.0")
// Bitwise-OR aggregate with node name "bit_or".
// NOTE(review): two class headers follow because this is a rendered diff — the first
// (extending DeclarativeAggregate directly) appears to be the removed version and the
// second (extending BitAggregate) the added one; confirm against the real file.
case class BitOrAgg(child: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
case class BitOrAgg(child: Expression) extends BitAggregate {

override def nodeName: String = "bit_or"

// NOTE(review): the members from here down to `initialValues` refer to a local `bitOr`
// buffer and look like the pre-refactor body — presumably removed by the diff; confirm.
override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

override def dataType: DataType = child.dataType

override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

private lazy val bitOr = AttributeReference("bit_or", child.dataType)()

override lazy val aggBufferAttributes: Seq[AttributeReference] = bitOr :: Nil

override lazy val initialValues: Seq[Literal] = Literal.create(null, dataType) :: Nil
// Combines two operands with a BitwiseOr expression.
override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
BitwiseOr(left, right)
}
}

override lazy val updateExpressions: Seq[Expression] =
If(IsNull(bitOr),
child,
If(IsNull(child), bitOr, BitwiseOr(bitOr, child))) :: Nil
/**
 * Bitwise-XOR aggregate: a [[BitAggregate]] whose node name is "bit_xor" and whose
 * combining operator is [[BitwiseXor]].
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the bitwise XOR of all non-null input values, or null if none.",
examples = """
Examples:
> SELECT _FUNC_(col) FROM VALUES (3), (5) AS tab(col);
6
""",
since = "3.0.0")
case class BitXorAgg(child: Expression) extends BitAggregate {

// NOTE(review): the `bitOr`-based mergeExpressions / evaluateExpression lines below refer to
// a name this class does not declare as shown; they look like removed lines of the old
// BitOrAgg that the diff rendering interleaved here — confirm.
override lazy val mergeExpressions: Seq[Expression] =
If(IsNull(bitOr.left),
bitOr.right,
If(IsNull(bitOr.right), bitOr.left, BitwiseOr(bitOr.left, bitOr.right))) :: Nil
override def nodeName: String = "bit_xor"

override lazy val evaluateExpression: AttributeReference = bitOr
// Combines two operands with a BitwiseXor expression.
override def bitOperator(left: Expression, right: Expression): BinaryArithmetic = {
BitwiseXor(left, right)
}
}
31 changes: 31 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/bitwise.sql
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,34 @@ select bit_count(-9223372036854775808L);
-- other illegal arguments
select bit_count("bit count");
select bit_count('a');

-- test for bit_xor
--
CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES
(1, 1, 1, 1L),
(2, 3, 4, null),
(7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4);

-- empty case
SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0;

-- null case
SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null;

-- the suffix numbers show the expected answer
SELECT
BIT_XOR(cast(b1 as tinyint)) AS a4,
BIT_XOR(cast(b2 as smallint)) AS b5,
BIT_XOR(b3) AS c2,
BIT_XOR(b4) AS d2,
BIT_XOR(distinct b4) AS e2
FROM bitwise_test;

-- group by
SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1;

-- having
SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7;

-- window
SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test;
71 changes: 70 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/bitwise.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 20
-- Number of queries: 27


-- !query 0
Expand Down Expand Up @@ -162,3 +162,72 @@ struct<>
-- !query 19 output
org.apache.spark.sql.AnalysisException
cannot resolve 'bit_count('a')' due to data type mismatch: argument 1 requires (integral or boolean) type, however, ''a'' is of string type.; line 1 pos 7


-- !query 20
CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES
(1, 1, 1, 1L),
(2, 3, 4, null),
(7, 7, 7, 3L) AS bitwise_test(b1, b2, b3, b4)
-- !query 20 schema
struct<>
-- !query 20 output



-- !query 21
SELECT BIT_XOR(b3) AS n1 FROM bitwise_test where 1 = 0
-- !query 21 schema
struct<n1:int>
-- !query 21 output
NULL


-- !query 22
SELECT BIT_XOR(b4) AS n1 FROM bitwise_test where b4 is null
-- !query 22 schema
struct<n1:bigint>
-- !query 22 output
NULL


-- !query 23
SELECT
BIT_XOR(cast(b1 as tinyint)) AS a4,
BIT_XOR(cast(b2 as smallint)) AS b5,
BIT_XOR(b3) AS c2,
BIT_XOR(b4) AS d2,
BIT_XOR(distinct b4) AS e2
FROM bitwise_test
-- !query 23 schema
struct<a4:tinyint,b5:smallint,c2:int,d2:bigint,e2:bigint>
-- !query 23 output
4 5 2 2 2


-- !query 24
SELECT bit_xor(b3) FROM bitwise_test GROUP BY b1 & 1
-- !query 24 schema
struct<bit_xor(b3):int>
-- !query 24 output
4
6


-- !query 25
SELECT b1, bit_xor(b2) FROM bitwise_test GROUP BY b1 HAVING bit_and(b2) < 7
-- !query 25 schema
struct<b1:int,bit_xor(b2):int>
-- !query 25 output
1 1
2 3


-- !query 26
SELECT b1, b2, bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2) FROM bitwise_test
-- !query 26 schema
struct<b1:int,b2:int,bit_xor(b2) OVER (PARTITION BY b1 ORDER BY b2 ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int>
-- !query 26 output
1 1 1
2 3 3
7 7 7