Skip to content

Commit 0c635d4

Browse files
committed
make match more specific
1 parent 9df388d commit 0c635d4

File tree

2 files changed

+19
-7
lines changed

2 files changed, +19 −7 lines changed

sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.sql.{SaveMode, AnalysisException}
2222
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog}
2323
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias}
2424
import org.apache.spark.sql.catalyst.plans.logical
25-
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
25+
import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.types.DataType
2828

@@ -121,8 +121,11 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
121121
failAnalysis(s"$l does not allow insertion.")
122122

123123
case logical.InsertIntoTable(t, _, _, _, _) =>
124-
failAnalysis(
125-
s"Attempt to insert into a RDD-based table: ${t.simpleString} which is immutable.")
124+
if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) {
125+
failAnalysis(s"Inserting into an RDD-based table is not allowed.")
126+
} else {
127+
// OK
128+
}
126129

127130
case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) =>
128131
// When the SaveMode is Overwrite, we need to check if the table is an input table of

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
21+
2022
import scala.language.postfixOps
2123

2224
import org.apache.spark.sql.functions._
@@ -762,7 +764,7 @@ class DataFrameSuite extends QueryTest {
762764
}
763765

764766
test("SPARK-6941: Better error message for inserting into RDD-based Table") {
765-
val df = Seq(Tuple1(1)).toDF("col")
767+
val df = Seq(Tuple1(1)).toDF()
766768
val insertion = Seq(Tuple1(2)).toDF("col")
767769

768770
// pass case: parquet table (HadoopFsRelation)
@@ -782,14 +784,21 @@ class DataFrameSuite extends QueryTest {
782784
val e1 = intercept[AnalysisException] {
783785
insertion.write.insertInto("rdd_base")
784786
}
785-
assert(e1.getMessage.contains("Attempt to insert into a RDD-based table"))
787+
assert(e1.getMessage.contains("Inserting into an RDD-based table is not allowed."))
786788

787789
// error case: insert into a RDD based on data source
788-
val indirectDS = pdf.select("col").filter($"col" > 5)
790+
val indirectDS = pdf.select("_1").filter($"_1" > 5)
789791
indirectDS.registerTempTable("indirect_ds")
790792
val e2 = intercept[AnalysisException] {
791793
insertion.write.insertInto("indirect_ds")
792794
}
793-
assert(e2.getMessage.contains("Attempt to insert into a RDD-based table"))
795+
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
796+
797+
// error case: insert into a OneRowRelation
798+
new DataFrame(ctx, OneRowRelation).registerTempTable("one_row")
799+
val e3 = intercept[AnalysisException] {
800+
insertion.write.insertInto("one_row")
801+
}
802+
assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
794803
}
795804
}

0 commit comments

Comments (0)