Skip to content

Commit 1821443

Browse files
committed
SPARK-27965: Add extractors for v2 catalog transforms.
1 parent aec0869 commit 1821443

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ private[sql] final case class BucketTransform(
9494
override def toString: String = describe
9595
}
9696

97+
private[sql] object BucketTransform {
98+
def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match {
99+
case NamedTransform("bucket", Seq(
100+
Lit(value: Int, IntegerType),
101+
Ref(seq: Seq[String]))) =>
102+
Some((value, FieldReference(seq)))
103+
case _ =>
104+
None
105+
}
106+
}
107+
97108
private[sql] final case class ApplyTransform(
98109
name: String,
99110
args: Seq[Expression]) extends Transform {
@@ -111,32 +122,104 @@ private[sql] final case class ApplyTransform(
111122
override def toString: String = describe
112123
}
113124

125+
/**
126+
* Convenience extractor for any Literal.
127+
*/
128+
private object Lit {
129+
def unapply[T](literal: Literal[T]): Some[(T, DataType)] = {
130+
Some((literal.value, literal.dataType))
131+
}
132+
}
133+
134+
/**
135+
* Convenience extractor for any NamedReference.
136+
*/
137+
private object Ref {
138+
def unapply(named: NamedReference): Some[Seq[String]] = {
139+
Some(named.fieldNames)
140+
}
141+
}
142+
143+
/**
144+
* Convenience extractor for any Transform.
145+
*/
146+
private object NamedTransform {
147+
def unapply(transform: Transform): Some[(String, Seq[Expression])] = {
148+
Some((transform.name, transform.arguments))
149+
}
150+
}
151+
114152
private[sql] final case class IdentityTransform(
115153
ref: NamedReference) extends SingleColumnTransform(ref) {
116154
override val name: String = "identity"
117155
override def describe: String = ref.describe
118156
}
119157

158+
private[sql] object IdentityTransform {
159+
def unapply(transform: Transform): Option[FieldReference] = transform match {
160+
case NamedTransform("identity", Seq(Ref(parts))) =>
161+
Some(FieldReference(parts))
162+
case _ =>
163+
None
164+
}
165+
}
166+
120167
private[sql] final case class YearsTransform(
121168
ref: NamedReference) extends SingleColumnTransform(ref) {
122169
override val name: String = "years"
123170
}
124171

172+
private[sql] object YearsTransform {
173+
def unapply(transform: Transform): Option[FieldReference] = transform match {
174+
case NamedTransform("years", Seq(Ref(parts))) =>
175+
Some(FieldReference(parts))
176+
case _ =>
177+
None
178+
}
179+
}
180+
125181
private[sql] final case class MonthsTransform(
126182
ref: NamedReference) extends SingleColumnTransform(ref) {
127183
override val name: String = "months"
128184
}
129185

186+
private[sql] object MonthsTransform {
187+
def unapply(transform: Transform): Option[FieldReference] = transform match {
188+
case NamedTransform("months", Seq(Ref(parts))) =>
189+
Some(FieldReference(parts))
190+
case _ =>
191+
None
192+
}
193+
}
194+
130195
private[sql] final case class DaysTransform(
131196
ref: NamedReference) extends SingleColumnTransform(ref) {
132197
override val name: String = "days"
133198
}
134199

200+
private[sql] object DaysTransform {
201+
def unapply(transform: Transform): Option[FieldReference] = transform match {
202+
case NamedTransform("days", Seq(Ref(parts))) =>
203+
Some(FieldReference(parts))
204+
case _ =>
205+
None
206+
}
207+
}
208+
135209
private[sql] final case class HoursTransform(
136210
ref: NamedReference) extends SingleColumnTransform(ref) {
137211
override val name: String = "hours"
138212
}
139213

214+
private[sql] object HoursTransform {
215+
def unapply(transform: Transform): Option[FieldReference] = transform match {
216+
case NamedTransform("hours", Seq(Ref(parts))) =>
217+
Some(FieldReference(parts))
218+
case _ =>
219+
None
220+
}
221+
}
222+
140223
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
141224
override def describe: String = {
142225
if (dataType.isInstanceOf[StringType]) {
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalog.v2.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst
22+
import org.apache.spark.sql.types.DataType
23+
24+
class TransformExtractorSuite extends SparkFunSuite {
25+
/**
26+
* Creates a Literal using an anonymous class.
27+
*/
28+
private def lit[T](literal: T): Literal[T] = new Literal[T] {
29+
override def value: T = literal
30+
override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
31+
override def describe: String = literal.toString
32+
}
33+
34+
/**
35+
* Creates a NamedReference using an anonymous class.
36+
*/
37+
private def ref(names: String*): NamedReference = new NamedReference {
38+
override def fieldNames: Array[String] = names.toArray
39+
override def describe: String = names.mkString(".")
40+
}
41+
42+
/**
43+
* Creates a Transform using an anonymous class.
44+
*/
45+
def transform(func: String, ref: NamedReference): Transform = new Transform {
46+
override def name: String = func
47+
override def references: Array[NamedReference] = Array(ref)
48+
override def arguments: Array[Expression] = Array(ref)
49+
override def describe: String = ref.describe
50+
}
51+
52+
/**
53+
* Creates a bucket Transform using an anonymous class.
54+
*/
55+
def bucket(numBuckets: Int, ref: NamedReference): Transform = new Transform {
56+
override def name: String = "bucket"
57+
override def references: Array[NamedReference] = Array(ref)
58+
override def arguments: Array[Expression] = Array(lit(numBuckets), ref)
59+
override def describe: String = ref.describe
60+
}
61+
62+
test("Identity extractor") {
63+
transform("identity", ref("a", "b")) match {
64+
case IdentityTransform(FieldReference(seq)) =>
65+
assert(seq === Seq("a", "b"))
66+
case _ =>
67+
fail("Did not match IdentityTransform extractor")
68+
}
69+
70+
transform("unknown", ref("a", "b")) match {
71+
case IdentityTransform(FieldReference(_)) =>
72+
fail("Matched unknown transform")
73+
case _ =>
74+
// expected
75+
}
76+
}
77+
78+
test("Years extractor") {
79+
transform("years", ref("a", "b")) match {
80+
case YearsTransform(FieldReference(seq)) =>
81+
assert(seq === Seq("a", "b"))
82+
case _ =>
83+
fail("Did not match YearsTransform extractor")
84+
}
85+
86+
transform("unknown", ref("a", "b")) match {
87+
case YearsTransform(FieldReference(_)) =>
88+
fail("Matched unknown transform")
89+
case _ =>
90+
// expected
91+
}
92+
}
93+
94+
test("Months extractor") {
95+
transform("months", ref("a", "b")) match {
96+
case MonthsTransform(FieldReference(seq)) =>
97+
assert(seq === Seq("a", "b"))
98+
case _ =>
99+
fail("Did not match MonthsTransform extractor")
100+
}
101+
102+
transform("unknown", ref("a", "b")) match {
103+
case MonthsTransform(FieldReference(_)) =>
104+
fail("Matched unknown transform")
105+
case _ =>
106+
// expected
107+
}
108+
}
109+
110+
test("Days extractor") {
111+
transform("days", ref("a", "b")) match {
112+
case DaysTransform(FieldReference(seq)) =>
113+
assert(seq === Seq("a", "b"))
114+
case _ =>
115+
fail("Did not match DaysTransform extractor")
116+
}
117+
118+
transform("unknown", ref("a", "b")) match {
119+
case DaysTransform(FieldReference(_)) =>
120+
fail("Matched unknown transform")
121+
case _ =>
122+
// expected
123+
}
124+
}
125+
126+
test("Hours extractor") {
127+
transform("hours", ref("a", "b")) match {
128+
case HoursTransform(FieldReference(seq)) =>
129+
assert(seq === Seq("a", "b"))
130+
case _ =>
131+
fail("Did not match HoursTransform extractor")
132+
}
133+
134+
transform("unknown", ref("a", "b")) match {
135+
case HoursTransform(FieldReference(_)) =>
136+
fail("Matched unknown transform")
137+
case _ =>
138+
// expected
139+
}
140+
}
141+
142+
test("Bucket extractor") {
143+
bucket(16, ref("a", "b")) match {
144+
case BucketTransform(numBuckets, FieldReference(seq)) =>
145+
assert(numBuckets === 16)
146+
assert(seq === Seq("a", "b"))
147+
case _ =>
148+
fail("Did not match BucketTransform extractor")
149+
}
150+
151+
transform("unknown", ref("a", "b")) match {
152+
case BucketTransform(_, _) =>
153+
fail("Matched unknown transform")
154+
case _ =>
155+
// expected
156+
}
157+
}
158+
}

0 commit comments

Comments
 (0)