diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 259be2679ce1..b25fff973c44 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -67,7 +67,9 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) val tableName = Identifiable.randomUID(uid) dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - dataset.sparkSession.sql(realStatement) + val result = dataset.sparkSession.sql(realStatement) + dataset.sparkSession.catalog.dropTempView(tableName) + result } @Since("1.6.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 23464073e6ed..753f890c4830 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -43,6 +43,7 @@ class SQLTransformerSuite assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) } test("read/write") {