Skip to content

Commit d3e196c

Browse files
committed
Make reducer take generic arguments
1 parent d2de9c3 commit d3e196c

File tree

3 files changed

+67
-53
lines changed

3 files changed

+67
-53
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,46 +60,47 @@
6060
@Evolving
6161
public interface ReducibleFunction<I, O> {
6262

63-
/**
64-
* This method is for bucket functions.
65-
*
66-
* If this bucket function is 'reducible' on another bucket function,
67-
* return the {@link Reducer} function.
68-
* <p>
69-
* Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x)
70-
* <ul>
71-
* <li>thisFunction = bucket</li>
72-
* <li>otherFunction = bucket</li>
73-
* <li>thisNumBuckets = Int(4)</li>
74-
* <li>otherNumBuckets = Int(2)</li>
75-
* </ul>
76-
*
77-
* @param otherFunction the other bucket function
78-
* @param thisNumBuckets number of buckets for this bucket function
79-
* @param otherNumBuckets number of buckets for the other bucket function
80-
* @return a reduction function if it is reducible, null if not
81-
*/
82-
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction,
83-
int thisNumBuckets,
84-
int otherNumBuckets) {
85-
return reducer(otherFunction);
86-
}
63+
/**
64+
* This method is for parameterized functions.
65+
*
66+
* If this parameterized function is 'reducible' on another parameterized function,
67+
* return the {@link Reducer} function.
68+
* <p>
69+
* Example of reducing f_source = bucket(4, x) on f_target = bucket(2, x)
70+
* <ul>
71+
* <li>thisFunction = bucket</li>
72+
* <li>thisParam = Int(4)</li>
73+
* <li>otherFunction = bucket</li>
74+
* <li>otherParam = Int(2)</li>
75+
* </ul>
76+
*
77+
* @param thisParam parameter for this function
78+
* @param otherFunction the other parameterized function
79+
* @param otherParam parameter for the other function
80+
* @return a reduction function if it is reducible, null if not
81+
*/
82+
default Reducer<I, O> reducer(
83+
Object thisParam,
84+
ReducibleFunction<?, ?> otherFunction,
85+
Object otherParam) {
86+
throw new UnsupportedOperationException();
87+
}
8788

88-
/**
89-
* This method is for all other functions.
90-
*
91-
* If this function is 'reducible' on another function, return the {@link Reducer} function.
92-
* <p>
93-
* Example of reducing f_source = days(x) on f_target = hours(x)
94-
* <ul>
95-
* <li>thisFunction = days</li>
96-
* <li>otherFunction = hours</li>
97-
* </ul>
98-
*
99-
* @param otherFunction the other function
100-
* @return a reduction function if it is reducible, null if not.
101-
*/
102-
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
103-
return reducer(otherFunction, 0, 0);
104-
}
89+
/**
90+
* This method is for all other functions.
91+
*
92+
* If this function is 'reducible' on another function, return the {@link Reducer} function.
93+
* <p>
94+
* Example of reducing f_source = days(x) on f_target = hours(x)
95+
* <ul>
96+
* <li>thisFunction = days</li>
97+
* <li>otherFunction = hours</li>
98+
* </ul>
99+
*
100+
* @param otherFunction the other function
101+
* @return a reduction function if it is reducible, null if not.
102+
*/
103+
default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
104+
throw new UnsupportedOperationException();
105+
}
105106
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ case class TransformExpression(
7070
} else {
7171
(function, other.function) match {
7272
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
73-
val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0))
74-
val otherReducer =
75-
o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0))
76-
reducer != null || otherReducer != null
73+
val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
74+
val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
75+
thisReducer.isDefined || otherReducer.isDefined
7776
case _ => false
7877
}
7978
}
@@ -91,14 +90,24 @@ case class TransformExpression(
9190
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
9291
(function, other.function) match {
9392
case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
94-
val reducer = e1.reducer(e2,
95-
numBucketsOpt.getOrElse(0),
96-
other.numBucketsOpt.getOrElse(0))
97-
Option(reducer)
93+
reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
9894
case _ => None
9995
}
10096
}
10197

98+
// Returns the Reducer, if any, for reducing thisFunction on otherFunction (None when null)
99+
private def reducer(thisFunction: ReducibleFunction[_, _],
100+
thisNumBucketsOpt: Option[Int],
101+
otherFunction: ReducibleFunction[_, _],
102+
otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
103+
val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
104+
case (Some(numBuckets), Some(otherNumBuckets)) =>
105+
thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
106+
case _ => thisFunction.reducer(otherFunction)
107+
}
108+
Option(res)
109+
}
110+
102111
override def dataType: DataType = function.resultType()
103112

104113
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =

sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,15 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
8686
(input.getLong(1) % input.getInt(0)).toInt
8787
}
8888

89-
override def reducer(func: ReducibleFunction[_, _],
90-
thisNumBuckets: Int,
91-
otherNumBuckets: Int): Reducer[Int, Int] = {
89+
override def reducer(
90+
thisNumBucketsArg: Object,
91+
otherFunc: ReducibleFunction[_, _],
92+
otherNumBucketsArg: Object): Reducer[Int, Int] = {
9293

93-
if (func == BucketFunction) {
94+
val thisNumBuckets = thisNumBucketsArg.asInstanceOf[Int]
95+
val otherNumBuckets = otherNumBucketsArg.asInstanceOf[Int]
96+
97+
if (otherFunc == BucketFunction) {
9498
if ((thisNumBuckets > otherNumBuckets)
9599
&& (thisNumBuckets % otherNumBuckets == 0)) {
96100
BucketReducer(thisNumBuckets, otherNumBuckets)

0 commit comments

Comments
 (0)