[SPARK-42412][WIP] Initial PR of Spark connect ML #40297
Conversation
connector/connect/common/src/main/protobuf/spark/connect/ml.proto (review thread resolved; outdated)
connector/connect/common/src/main/protobuf/spark/connect/expressions.proto (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala (review thread resolved; outdated)
connector/connect/common/src/main/protobuf/spark/connect/relations.proto (review thread resolved)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala (three more review threads resolved; outdated)
grundprinzip left a comment:
First round of reviews on the protos.
  MlEvaluator evaluator = 1;
}

message LoadModel {
Would this work with an arbitrary model, for example one provided by Spark NLP?
The current PR does not support third-party estimators. To support a third-party algorithm, we need to register its class with the AlgorithmRegistry class.
If we want to support third-party algorithms without a registry, then we inevitably have to use Java reflection to invoke methods (e.g. we need to invoke XXXModel.load to load a model), which is unsafe.
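To make the reflection concern concrete, here is a minimal Scala sketch (illustrative only, not code from this PR; the helper name is invented) of the kind of call this would require. The client supplies the class name, so the server ends up invoking arbitrary classes, which is the unsafe part:

import org.apache.spark.ml.Model

// Hypothetical helper, for illustration only: load a model via reflection.
def loadModelByReflection(className: String, path: String): Model[_] = {
  // A Scala companion object compiles to a class named "<className>$"
  // whose singleton instance is held in the static MODULE$ field.
  val companionClass = Class.forName(className + "$")
  val companion = companionClass.getField("MODULE$").get(null)
  // Invoke the companion's load(path) method, i.e. XXXModel.load(path).
  val loadMethod = companionClass.getMethod("load", classOf[String])
  loadMethod.invoke(companion, path).asInstanceOf[Model[_]]
}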
By the way, supporting third-party estimators is risky, because on a shared cluster we will bin-pack the Spark workers across different customers (according to @mengxr's explanation). A third-party estimator implementation might invoke RDD transformations (e.g. RDD.map) that we cannot isolate by container, so it is risky to allow users to run third-party estimators on a shared cluster.
  MlParams params = 2;
  string uid = 3;
  StageType type = 4;
  enum StageType {
Is this knowledge actually required on the client?
Yes.
Alternatively, we could make the server side infer the stage type from the stage name, but having the client fill in the stage type keeps the code simpler.
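For what server-side inference could look like, a hedged Scala sketch (illustrative only, not part of this PR): resolve the stage's class and check it against the ML base types at runtime.

import org.apache.spark.ml.{Estimator, Transformer}

// Illustrative only: infer a stage's type from its class name on the server,
// instead of having the client send it in the proto message.
def inferStageType(stageClassName: String): String = {
  val cls = Class.forName(stageClassName)
  if (classOf[Estimator[_]].isAssignableFrom(cls)) "ESTIMATOR"
  else if (classOf[Transformer].isAssignableFrom(cls)) "TRANSFORMER"
  else "UNSPECIFIED"
}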
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto (review thread resolved)
}

message ModelTransform {
  Relation input = 1;
  int64 model_ref_id = 2;
My suggestion here is to wrap the model_ref_id in an extra message, which makes it easier to extend:

message ModelRef {
  int64 id = 1;
}

That said, is there a reason the ID is numeric rather than a string?
The ID is generated from an incrementing counter, so I think the int64 type should be fine.
message ModelRef {
  int64 id = 1;
}

This sounds good.
> The ID is generated from an incrementing counter.

Using a random UUID might be a better idea if we want to support server failover in the future: we would need to persist the status and restore it, and a random UUID helps avoid reusing an ID that was generated before.
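A small sketch contrasting the two schemes (illustrative; neither object exists in this PR, and the AtomicLong is an assumption about how the counter is implemented). The counter is simple but restarts from zero after a failover, while UUIDs never collide with previously persisted IDs, at the cost of changing the proto field from int64 to string:

import java.util.UUID
import java.util.concurrent.atomic.AtomicLong

// Counter scheme: compact int64 IDs, but a restarted server reissues
// values that may already exist in persisted state.
object CounterModelRefIds {
  private val counter = new AtomicLong(0L)
  def next(): Long = counter.incrementAndGet()
}

// UUID scheme: random string IDs, safe to issue after restoring state.
object UuidModelRefIds {
  def next(): String = UUID.randomUUID().toString
}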
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala (review thread resolved; outdated)
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala (review thread resolved)
    return remote_cls


def try_remote_ml_class(x):
I feel we can also simplify the pyspark.sql side by applying this annotation to only a few key classes.
cc @HyukjinKwon
private[spark] def _setDefault(paramPairs: ParamPair[_]*): this.type = {
  setDefault(paramPairs: _*)
}
I think we can simply change setDefault to protected[spark]?
> I think we can simply change setDefault to protected[spark]?

That would be a breaking change. Some third-party estimators might override this method, and if they are not under the org.apache package, compilation will fail.
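A toy Scala example of the breakage being described, using stand-in types rather than Spark's actual Params trait: an override may not have weaker access than the member it overrides, and code outside the spark package cannot write protected[spark], so widening the base method's visibility would strand any existing third-party override.

package org.apache.spark {
  abstract class Base {
    // Today's visibility. Changing this to protected[spark] would make the
    // third-party override below fail with "weaker access privileges".
    protected def setDefault(): Unit
  }
}

package com.example {
  class ThirdParty extends org.apache.spark.Base {
    override protected def setDefault(): Unit = ()
  }
}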
from abc import ABCMeta, abstractmethod

import os
Is this import needed?
@classmethod
def getActiveSession(cls) -> Any:
    raise NotImplementedError("getActiveSession() is not implemented.")
Do we need this change? I thought we could use the newly added getOrCreate.
Oh, I will revert this.
  UNSPECIFIED = 0;
  ESTIMATOR = 1;
  TRANSFORMER = 2;
We normally name enum values like this:

  STAGE_TYPE_UNSPECIFIED = 0;
  STAGE_TYPE_ESTIMATOR = 1;
  STAGE_TYPE_TRANSFORMER = 2;
globs = pyspark.sql.connect.dataframe.__dict__.copy()

globs["spark"] = (
    PySparkSession.builder.appName("sql.connect.ml.classification tests")
Suggested change:
- PySparkSession.builder.appName("sql.connect.ml.classification tests")
+ PySparkSession.builder.appName("ml.connect.classification tests")
The doctest should be added in sparktestsupport/modules.py.
@@ -0,0 +1,61 @@
/*
Will we move these ML files to connector/connect/server/src/main/scala/org/apache/spark/ml/connect?
  }
}

class LogisticRegressionAlgorithm extends Algorithm {
If we can use Java reflection to invoke methods, we don't need the registry class; we just need some configuration data for the registry. If we plan to make Spark Connect mode mandatory for DBR starting with Spark 4, then we had better use Java reflection invocation; otherwise it is hard to support a huge number of third-party estimators.
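One hedged way to picture "configuration data instead of a registry class" (illustrative only; the object and config shape are invented, though the estimator class name is real): keep an allowlisted set of class names and instantiate them reflectively, so supporting a new algorithm means adding a configuration entry rather than writing another Algorithm subclass.

import org.apache.spark.ml.Estimator

object AllowlistedEstimators {
  // Hypothetical configuration: the estimator classes the server may
  // instantiate on behalf of a client.
  private val allowed: Set[String] = Set(
    "org.apache.spark.ml.classification.LogisticRegression"
  )

  def create(className: String): Estimator[_] = {
    require(allowed.contains(className), s"Estimator not allowlisted: $className")
    // Reflective no-arg construction replaces a hand-written per-estimator
    // registry class.
    Class.forName(className)
      .getDeclaredConstructor()
      .newInstance()
      .asInstanceOf[Estimator[_]]
  }
}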
We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
What changes were proposed in this pull request?
Design doc:
https://docs.google.com/document/d/1V5rOgksmOnA8AsJFZ_rasSYDQuP06_vrcfp3RY_22o8/edit#
Why are the changes needed?
Does this PR introduce any user-facing change?
How was this patch tested?
Testing code: run the command bin/pyspark --remote local, then in the Python REPL run the following code: