From 1821443196e91f3f1c97362ffa96e6e9519b5754 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 5 Jun 2019 16:41:04 -0700 Subject: [PATCH 1/2] SPARK-27965: Add extractors for v2 catalog transforms. --- .../catalog/v2/expressions/expressions.scala | 83 +++++++++ .../expressions/TransformExtractorSuite.scala | 158 ++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala index 2d4d6e7c6d5e..ea5fc05dd5ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala @@ -94,6 +94,17 @@ private[sql] final case class BucketTransform( override def toString: String = describe } +private[sql] object BucketTransform { + def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { + case NamedTransform("bucket", Seq( + Lit(value: Int, IntegerType), + Ref(seq: Seq[String]))) => + Some((value, FieldReference(seq))) + case _ => + None + } +} + private[sql] final case class ApplyTransform( name: String, args: Seq[Expression]) extends Transform { @@ -111,32 +122,104 @@ private[sql] final case class ApplyTransform( override def toString: String = describe } +/** + * Convenience extractor for any Literal. + */ +private object Lit { + def unapply[T](literal: Literal[T]): Some[(T, DataType)] = { + Some((literal.value, literal.dataType)) + } +} + +/** + * Convenience extractor for any NamedReference. + */ +private object Ref { + def unapply(named: NamedReference): Some[Seq[String]] = { + Some(named.fieldNames) + } +} + +/** + * Convenience extractor for any Transform. + */ +private object NamedTransform { + def unapply(transform: Transform): Some[(String, Seq[Expression])] = { + Some((transform.name, transform.arguments)) + } +} + private[sql] final case class IdentityTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "identity" override def describe: String = ref.describe } +private[sql] object IdentityTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("identity", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class YearsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "years" } +private[sql] object YearsTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("years", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class MonthsTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "months" } +private[sql] object MonthsTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("months", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class DaysTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "days" } +private[sql] object DaysTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("days", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class HoursTransform( ref: NamedReference) extends SingleColumnTransform(ref) { override val name: String = "hours" } +private[sql] object HoursTransform { + def unapply(transform: Transform): Option[FieldReference] = transform match { + case NamedTransform("hours", Seq(Ref(parts))) => + Some(FieldReference(parts)) + case _ => + None + } +} + private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { override def describe: String = { if (dataType.isInstanceOf[StringType]) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala new file mode 100644 index 000000000000..fbdaced49d9b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.DataType + +class TransformExtractorSuite extends SparkFunSuite { + /** + * Creates a Literal using an anonymous class. + */ + private def lit[T](literal: T): Literal[T] = new Literal[T] { + override def value: T = literal + override def dataType: DataType = catalyst.expressions.Literal(literal).dataType + override def describe: String = literal.toString + } + + /** + * Creates a NamedReference using an anonymous class. + */ + private def ref(names: String*): NamedReference = new NamedReference { + override def fieldNames: Array[String] = names.toArray + override def describe: String = names.mkString(".") + } + + /** + * Creates a Transform using an anonymous class. + */ + def transform(func: String, ref: NamedReference): Transform = new Transform { + override def name: String = func + override def references: Array[NamedReference] = Array(ref) + override def arguments: Array[Expression] = Array(ref) + override def describe: String = ref.describe + } + + /** + * Creates a bucket Transform using an anonymous class. + */ + def bucket(numBuckets: Int, ref: NamedReference): Transform = new Transform { + override def name: String = "bucket" + override def references: Array[NamedReference] = Array(ref) + override def arguments: Array[Expression] = Array(lit(numBuckets), ref) + override def describe: String = ref.describe + } + + test("Identity extractor") { + transform("identity", ref("a", "b")) match { + case IdentityTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match IdentityTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case IdentityTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Years extractor") { + transform("years", ref("a", "b")) match { + case YearsTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match YearsTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case YearsTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Months extractor") { + transform("months", ref("a", "b")) match { + case MonthsTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match MonthsTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case MonthsTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Days extractor") { + transform("days", ref("a", "b")) match { + case DaysTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match DaysTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case DaysTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Hours extractor") { + transform("hours", ref("a", "b")) match { + case HoursTransform(FieldReference(seq)) => + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match HoursTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case HoursTransform(FieldReference(_)) => + fail("Matched unknown transform") + case _ => + // expected + } + } + + test("Bucket extractor") { + bucket(16, ref("a", "b")) match { + case BucketTransform(numBuckets, FieldReference(seq)) => + assert(numBuckets === 16) + assert(seq === Seq("a", "b")) + case _ => + fail("Did not match BucketTransform extractor") + } + + transform("unknown", ref("a", "b")) match { + case BucketTransform(_, _) => + fail("Matched unknown transform") + case _ => + // expected + } + } +} From 453bb926ba230c486085014b58eb1468ac960baa Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 6 Jun 2019 09:40:29 -0700 Subject: [PATCH 2/2] Update tests for review comments. --- .../expressions/TransformExtractorSuite.scala | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala index fbdaced49d9b..c0a5dada19db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/expressions/TransformExtractorSuite.scala @@ -42,23 +42,13 @@ class TransformExtractorSuite extends SparkFunSuite { /** * Creates a Transform using an anonymous class. */ - def transform(func: String, ref: NamedReference): Transform = new Transform { + private def transform(func: String, ref: NamedReference): Transform = new Transform { override def name: String = func override def references: Array[NamedReference] = Array(ref) override def arguments: Array[Expression] = Array(ref) override def describe: String = ref.describe } - /** - * Creates a bucket Transform using an anonymous class. - */ - def bucket(numBuckets: Int, ref: NamedReference): Transform = new Transform { - override def name: String = "bucket" - override def references: Array[NamedReference] = Array(ref) - override def arguments: Array[Expression] = Array(lit(numBuckets), ref) - override def describe: String = ref.describe - } - test("Identity extractor") { transform("identity", ref("a", "b")) match { case IdentityTransform(FieldReference(seq)) => @@ -71,7 +61,7 @@ class TransformExtractorSuite extends SparkFunSuite { case IdentityTransform(FieldReference(_)) => fail("Matched unknown transform") case _ => - // expected + // expected } } @@ -87,7 +77,7 @@ class TransformExtractorSuite extends SparkFunSuite { case YearsTransform(FieldReference(_)) => fail("Matched unknown transform") case _ => - // expected + // expected } } @@ -103,7 +93,7 @@ class TransformExtractorSuite extends SparkFunSuite { case MonthsTransform(FieldReference(_)) => fail("Matched unknown transform") case _ => - // expected + // expected } } @@ -119,7 +109,7 @@ class TransformExtractorSuite extends SparkFunSuite { case DaysTransform(FieldReference(_)) => fail("Matched unknown transform") case _ => - // expected + // expected } } @@ -140,7 +130,15 @@ class TransformExtractorSuite extends SparkFunSuite { } test("Bucket extractor") { - bucket(16, ref("a", "b")) match { + val col = ref("a", "b") + val bucketTransform = new Transform { + override def name: String = "bucket" + override def references: Array[NamedReference] = Array(col) + override def arguments: Array[Expression] = Array(lit(16), col) + override def describe: String = s"bucket(16, ${col.describe})" + } + + bucketTransform match { case BucketTransform(numBuckets, FieldReference(seq)) => assert(numBuckets === 16) assert(seq === Seq("a", "b"))