Skip to content

Commit aa92e84

Browse files
committed
Update data type tests.
1 parent 8da1a17 commit aa92e84

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,19 +295,31 @@ object StructType {
295295
case class StructType(fields: Seq[StructField]) extends DataType {
296296
require(StructType.validateFields(fields), "Found fields with the same name.")
297297

298+
/**
299+
* Returns all field names in a [[Seq]].
300+
*/
301+
lazy val fieldNames: Seq[String] = fields.map(_.name)
302+
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
303+
298304
/**
299305
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
300306
* have a name matching the given name, `null` will be returned.
301307
*/
302308
def apply(name: String): StructField = {
303-
fields.find(f => f.name == name).orNull
309+
fields.find(f => f.name == name).getOrElse(
310+
throw new IllegalArgumentException(s"Field ${name} does not exist."))
304311
}
305312

306313
/**
307314
* Returns a [[StructType]] containing [[StructField]]s of the given names.
308315
* Those names which do not have matching fields will be ignored.
309316
*/
310317
def apply(names: Set[String]): StructType = {
318+
val nonExistFields = names -- fieldNamesSet
319+
if (!nonExistFields.isEmpty) {
320+
throw new IllegalArgumentException(
321+
s"Field ${nonExistFields.mkString(",")} does not exist.")
322+
}
311323
StructType(fields.filter(f => names.contains(f.name)))
312324
}
313325

sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala renamed to sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.FunSuite
2121

22-
class SchemaSuite extends FunSuite {
22+
class DataTypeSuite extends FunSuite {
2323

24-
test("constructing an ArrayType") {
24+
test("construct an ArrayType") {
2525
val array = ArrayType(StringType)
2626

2727
assert(ArrayType(StringType, false) === array)
2828
}
2929

30-
test("extracting fields from a StructType") {
30+
test("extract fields from a StructType") {
3131
val struct = StructType(
3232
StructField("a", IntegerType, true) ::
3333
StructField("b", LongType, false) ::
@@ -36,14 +36,17 @@ class SchemaSuite extends FunSuite {
3636

3737
assert(StructField("b", LongType, false) === struct("b"))
3838

39-
assert(struct("e") === null)
39+
intercept[IllegalArgumentException] {
40+
struct("e")
41+
}
4042

4143
val expectedStruct = StructType(
4244
StructField("b", LongType, false) ::
4345
StructField("d", FloatType, true) :: Nil)
4446

4547
assert(expectedStruct === struct(Set("b", "d")))
46-
// struct does not have a field called e. So e is ignored.
47-
assert(expectedStruct === struct(Set("b", "d", "e")))
48+
intercept[IllegalArgumentException] {
49+
struct(Set("b", "d", "e", "f"))
50+
}
4851
}
4952
}

0 commit comments

Comments
 (0)