From bdd8169d71e9dfb6411f43505e499dcfe989cc91 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 10 Feb 2025 10:03:08 +0800 Subject: [PATCH] fix --- common/utils/src/main/resources/error/error-conditions.json | 2 +- .../scala/org/apache/spark/sql/connect/ml/MLException.scala | 4 ++-- .../main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala | 2 +- .../test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala | 5 ++++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 68ae83867f35..b9a8012275d2 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -772,7 +772,7 @@ "subClass" : { "ATTRIBUTE_NOT_ALLOWED" : { "message" : [ - " is not allowed to be accessed." + " in is not allowed to be accessed." ] }, "UNSUPPORTED_EXCEPTION" : { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala index eb88bf9169d3..7700eccf6553 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala @@ -25,8 +25,8 @@ private[spark] case class MlUnsupportedException(message: String) messageParameters = Map("message" -> message), cause = null) -private[spark] case class MLAttributeNotAllowedException(attribute: String) +private[spark] case class MLAttributeNotAllowedException(className: String, attribute: String) extends SparkException( errorClass = "CONNECT_ML.ATTRIBUTE_NOT_ALLOWED", - messageParameters = Map("attribute" -> attribute), + messageParameters = Map("className" -> className, "attribute" -> attribute), cause = null) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index 84a26d9e4962..c999772b7d82 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -674,7 +674,7 @@ private[ml] object MLUtils { cls.isInstance(obj) && methods.contains(method) } if (!valid) { - throw MLAttributeNotAllowedException(method) + throw MLAttributeNotAllowedException(obj.getClass.getName, method) } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala index c3ab6248be8f..cc24a2a67439 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala @@ -261,9 +261,12 @@ class MLSuite extends MLHelper { val modelId = trainLogisticRegressionModel(sessionHolder) val fakeAttributeCmd = fetchCommand(modelId, "notExistingAttribute") - intercept[MLAttributeNotAllowedException] { + val e = intercept[MLAttributeNotAllowedException] { MLHandler.handleMlCommand(sessionHolder, fakeAttributeCmd) } + val msg = e.getMessage + assert(msg.contains("notExistingAttribute")) + assert(msg.contains("org.apache.spark.ml.classification.LogisticRegressionModel")) } test("Model must be registered into ServiceLoader when loading") {