Skip to content

Commit 928fb84

Browse files
committed
Maybe done
1 parent a910ac7 commit 928fb84

File tree

5 files changed

+38
-17
lines changed

5 files changed

+38
-17
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.apache.spark.ml.classification.ClassificationModel;
2929
import org.apache.spark.ml.param.IntParam;
3030
import org.apache.spark.ml.param.ParamMap;
31-
import org.apache.spark.ml.param.Params;
3231
import org.apache.spark.ml.param.Params$;
3332
import org.apache.spark.mllib.linalg.BLAS;
3433
import org.apache.spark.mllib.linalg.Vector;
@@ -100,10 +99,12 @@ public static void main(String[] args) throws Exception {
10099
/**
101100
* Example of defining a type of {@link Classifier}.
102101
*
103-
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
102+
* Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
103+
* {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
104+
* However, this should still compile and run successfully.
104105
*/
105106
class MyJavaLogisticRegression
106-
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> implements Params {
107+
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
107108

108109
/**
109110
* Param for max number of iterations
@@ -144,10 +145,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap)
144145
/**
145146
* Example of defining a type of {@link ClassificationModel}.
146147
*
147-
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
148+
* Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
149+
* {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
150+
* However, this should still compile and run successfully.
148151
*/
149152
class MyJavaLogisticRegressionModel
150-
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
153+
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
151154

152155
private MyJavaLogisticRegression parent_;
153156
public MyJavaLogisticRegression parent() { return parent_; }

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
4343
extends Serializable {
4444

4545
def this(parent: Params, name: String, doc: String) =
46-
this(parent, name, doc, ParamValidate.default[T])
46+
this(parent, name, doc, ParamValidate.alwaysTrue[T])
4747

4848
/**
4949
* Assert that the given value is valid for this parameter.
@@ -96,8 +96,8 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
9696
*/
9797
object ParamValidate {
9898

99-
/** Default validation always return true */
100-
def default[T]: T => Boolean = (_: T) => true
99+
/** (private[param]) Default validation always return true */
100+
private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
101101

102102
/**
103103
* Private method for checking numerical types and converting to Double.
@@ -176,7 +176,7 @@ class DoubleParam(parent: Params, name: String, doc: String, isValid: Double =>
176176
extends Param[Double](parent, name, doc, isValid) {
177177

178178
def this(parent: Params, name: String, doc: String) =
179-
this(parent, name, doc, ParamValidate.default[Double])
179+
this(parent, name, doc, ParamValidate.alwaysTrue[Double])
180180

181181
override def w(value: Double): ParamPair[Double] = super.w(value)
182182
}
@@ -186,7 +186,7 @@ class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolea
186186
extends Param[Int](parent, name, doc) {
187187

188188
def this(parent: Params, name: String, doc: String) =
189-
this(parent, name, doc, ParamValidate.default[Int])
189+
this(parent, name, doc, ParamValidate.alwaysTrue[Int])
190190

191191
override def w(value: Int): ParamPair[Int] = super.w(value)
192192
}
@@ -196,7 +196,7 @@ class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Bo
196196
extends Param[Float](parent, name, doc) {
197197

198198
def this(parent: Params, name: String, doc: String) =
199-
this(parent, name, doc, ParamValidate.default[Float])
199+
this(parent, name, doc, ParamValidate.alwaysTrue[Float])
200200

201201
override def w(value: Float): ParamPair[Float] = super.w(value)
202202
}
@@ -206,7 +206,7 @@ class LongParam(parent: Params, name: String, doc: String, isValid: Long => Bool
206206
extends Param[Long](parent, name, doc) {
207207

208208
def this(parent: Params, name: String, doc: String) =
209-
this(parent, name, doc, ParamValidate.default[Long])
209+
this(parent, name, doc, ParamValidate.alwaysTrue[Long])
210210

211211
override def w(value: Long): ParamPair[Long] = super.w(value)
212212
}
@@ -351,11 +351,13 @@ trait Params extends Identifiable with Serializable {
351351

352352
/**
353353
* Sets default values for a list of params.
354+
*
355+
* Note: Java developers should use the single-parameter [[setDefault()]].
356+
*
354357
* @param paramPairs a list of param pairs that specify params and their default values to set
355358
* respectively. Make sure that the params are initialized before this method
356359
* gets called.
357360
*/
358-
@varargs
359361
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
360362
paramPairs.foreach { p =>
361363
setDefault(p.param.asInstanceOf[Param[Any]], p.value)

mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.apache.spark.ml.param;
22

3+
import com.google.common.collect.Lists;
34
import org.junit.After;
45
import org.junit.Assert;
56
import org.junit.Before;
@@ -29,7 +30,21 @@ public void tearDown() {
2930
public void testParams() {
3031
JavaTestParams testParams = new JavaTestParams();
3132
Assert.assertEquals(testParams.getMyIntParam(), 1);
32-
testParams.setMyIntParam(2).setMyDoubleParam(0.4);
33+
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
3334
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
35+
Assert.assertEquals(testParams.getMyStringParam(), "a");
36+
}
37+
38+
@Test
39+
public void testParamValidate() {
40+
ParamValidate.alwaysTrue();
41+
ParamValidate.gt(1.0);
42+
ParamValidate.gtEq(1.0);
43+
ParamValidate.lt(1.0);
44+
ParamValidate.ltEq(1.0);
45+
ParamValidate.inRange(0, 1, true, false);
46+
ParamValidate.inRange(0, 1);
47+
ParamValidate.inArray(Lists.newArrayList(0, 1, 3));
48+
ParamValidate.inArray(Lists.newArrayList("a", "b"));
3449
}
3550
}

mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public JavaTestParams() {
4040
List<String> validStrings = Lists.newArrayList("a", "b");
4141
myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
4242
ParamValidate.inArray(validStrings));
43-
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
43+
setDefault(myIntParam, 1);
44+
setDefault(myDoubleParam, 0.5);
4445
}
4546
}

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ class ParamsSuite extends FunSuite {
147147
}
148148

149149
test("ParamValidate") {
150-
val default = ParamValidate.default[Int]
151-
assert(default(1))
150+
val alwaysTrue = ParamValidate.alwaysTrue[Int]
151+
assert(alwaysTrue(1))
152152

153153
val gt1Int = ParamValidate.gt[Int](1)
154154
assert(!gt1Int(1) && gt1Int(2))

0 commit comments

Comments
 (0)