Skip to content

Commit a910ac7

Browse files
committed
still workin
1 parent 6d60e2e commit a910ac7

File tree

6 files changed

+81
-69
lines changed

6 files changed

+81
-69
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.spark.ml.classification.ClassificationModel;
2929
import org.apache.spark.ml.param.IntParam;
3030
import org.apache.spark.ml.param.ParamMap;
31+
import org.apache.spark.ml.param.Params;
3132
import org.apache.spark.ml.param.Params$;
3233
import org.apache.spark.mllib.linalg.BLAS;
3334
import org.apache.spark.mllib.linalg.Vector;
@@ -102,7 +103,7 @@ public static void main(String[] args) throws Exception {
102103
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
103104
*/
104105
class MyJavaLogisticRegression
105-
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
106+
extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> implements Params {
106107

107108
/**
108109
* Param for max number of iterations
@@ -146,7 +147,7 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap)
146147
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
147148
*/
148149
class MyJavaLogisticRegressionModel
149-
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
150+
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
150151

151152
private MyJavaLogisticRegression parent_;
152153
public MyJavaLogisticRegression parent() { return parent_; }

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ package org.apache.spark.ml.param
2020
import java.lang.reflect.Modifier
2121
import java.util.NoSuchElementException
2222

23-
import scala.collection.mutable
2423
import scala.annotation.varargs
24+
import scala.collection.mutable
2525

2626
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
2727
import org.apache.spark.ml.util.Identifiable
@@ -90,45 +90,50 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal
9090
}
9191
}
9292

93-
/** Factory methods for common validation functions for [[Param.isValid]] */
93+
/**
94+
* Factory methods for common validation functions for [[Param.isValid]].
95+
* The numerical methods only support Int, Long, Float, and Double.
96+
*/
9497
object ParamValidate {
9598

9699
/** Default validation always return true */
97100
def default[T]: T => Boolean = (_: T) => true
98101

99-
/** Negate the given check */
100-
def not[T](isValid: T => Boolean): T => Boolean = { (value: T) =>
101-
!isValid(value)
102+
/**
103+
* Private method for checking numerical types and converting to Double.
104+
* This is mainly for the sake of compilation; type checks are really handled
105+
* by [[Params]] setters and the [[ParamPair]] constructor.
106+
*/
107+
private def getDouble[T](value: T): Double = value match {
108+
case x: Int => x.toDouble
109+
case x: Long => x.toDouble
110+
case x: Float => x.toDouble
111+
case x: Double => x.toDouble
112+
case _ =>
113+
// The type should be checked before this is ever called.
114+
throw new IllegalArgumentException("Numerical Param validation failed because" +
115+
s" of unexpected input type: ${value.getClass}")
102116
}
103117

104-
/** Combine two checks */
105-
def and[T](isValid1: T => Boolean, isValid2: T => Boolean): T => Boolean = { (value: T) =>
106-
isValid1(value) && isValid2(value)
118+
/** Check if value > lowerBound */
119+
def gt[T](lowerBound: Double): T => Boolean = { (value: T) =>
120+
getDouble(value) > lowerBound
107121
}
108122

109-
/** Check for value > lowerBound. Use [[not()]] for <= check. */
110-
def gt(lowerBound: Int): Int => Boolean = { (value: Int) => value > lowerBound }
111-
112-
/** Check for value >= lowerBound. Use [[not()]] for < check. */
113-
def gtEq(lowerBound: Int): Int => Boolean = { (value: Int) => value >= lowerBound }
114-
115-
/** Check for value > lowerBound. Use [[not()]] for <= check. */
116-
def gt(lowerBound: Long): Long => Boolean = { (value: Long) => value > lowerBound }
117-
118-
/** Check for value >= lowerBound. Use [[not()]] for < check. */
119-
def gtEq(lowerBound: Long): Long => Boolean = { (value: Long) => value >= lowerBound }
120-
121-
/** Check for value > lowerBound. Use [[not()]] for <= check. */
122-
def gt(lowerBound: Float): Float => Boolean = { (value: Float) => value > lowerBound }
123-
124-
/** Check for value >= lowerBound. Use [[not()]] for < check. */
125-
def gtEq(lowerBound: Float): Float => Boolean = { (value: Float) => value >= lowerBound }
123+
/** Check if value >= lowerBound */
124+
def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) =>
125+
getDouble(value) >= lowerBound
126+
}
126127

127-
/** Check for value > lowerBound. Use [[not()]] for <= check. */
128-
def gt(lowerBound: Double): Double => Boolean = { (value: Double) => value > lowerBound }
128+
/** Check if value < upperBound */
129+
def lt[T](upperBound: Double): T => Boolean = { (value: T) =>
130+
getDouble(value) < upperBound
131+
}
129132

130-
/** Check for value >= lowerBound. Use [[not()]] for < check. */
131-
def gtEq(lowerBound: Double): Double => Boolean = { (value: Double) => value >= lowerBound }
133+
/** Check if value <= upperBound */
134+
def ltEq[T](upperBound: Double): T => Boolean = { (value: T) =>
135+
getDouble(value) <= upperBound
136+
}
132137

133138
/**
134139
* Check for value in range lowerBound to upperBound.
@@ -137,33 +142,31 @@ object ParamValidate {
137142
* @param upperInclusive If true, check for value <= upperBound.
138143
* If false, check for value < upperBound.
139144
*/
140-
def inRange[T <: Comparable[T]](
141-
lowerBound: T,
142-
upperBound: T,
145+
def inRange[T](
146+
lowerBound: Double,
147+
upperBound: Double,
143148
lowerInclusive: Boolean,
144-
upperInclusive: Boolean): T => Boolean = { (x: T) =>
145-
val lowerValid = if (lowerInclusive) {
146-
x.compareTo(lowerBound) >= 0
147-
} else {
148-
x.compareTo(lowerBound) > 0
149-
}
150-
val upperValid = if (upperInclusive) {
151-
x.compareTo(upperBound) <= 0
152-
} else {
153-
x.compareTo(upperBound) < 0
154-
}
149+
upperInclusive: Boolean): T => Boolean = { (value: T) =>
150+
val x: Double = getDouble(value)
151+
val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
152+
val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
155153
lowerValid && upperValid
156154
}
157155

158156
/** Version of [[inRange()]] which uses inclusive be default: [lowerBound, upperBound] */
159-
def inRange[T](lowerBound: T, upperBound: T): T => Boolean = {
157+
def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
160158
inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
161159
}
162160

163161
/** Check for value in an allowed set of values. */
164162
def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) =>
165163
allowed.contains(value)
166164
}
165+
166+
/** Check for value in an allowed set of values. */
167+
def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
168+
allowed.contains(value)
169+
}
167170
}
168171

169172
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,9 @@ private[shared] object SharedParamsCodeGen {
129129

130130
s"""
131131
|/**
132-
| * :: DeveloperApi ::
133-
| * Trait for shared param $name$defaultValueDoc.
132+
| * (private[ml]) Trait for shared param $name$defaultValueDoc.
134133
| */
135-
|@DeveloperApi
136-
|trait Has$Name extends Params {
134+
|private[ml] trait Has$Name extends Params {
137135
|
138136
| /**
139137
| * Param for $doc.

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
5656
* Default: 10
5757
* @group param
5858
*/
59-
val rank = new IntParam(this, "rank", "rank of the factorization",
60-
isValid = ParamValidate.gtEq[Int](1))
59+
val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidate.gtEq[Int](1))
6160

6261
/** @group getParam */
6362
def getRank: Int = getOrDefault(rank)
@@ -68,7 +67,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
6867
* @group param
6968
*/
7069
val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
71-
isValid = ParamValidate.gtEq[Int](1))
70+
ParamValidate.gtEq[Int](1))
7271

7372
/** @group getParam */
7473
def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
@@ -78,9 +77,8 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
7877
* Default: 10
7978
* @group param
8079
*/
81-
val numItemBlocks =
82-
new IntParam(this, "numItemBlocks", "number of item blocks",
83-
isValid = ParamValidate.gtEq[Int](1))
80+
val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
81+
ParamValidate.gtEq[Int](1))
8482

8583
/** @group getParam */
8684
def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
@@ -101,7 +99,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
10199
* @group param
102100
*/
103101
val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
104-
isValid = ParamValidate.gtEq[Double](0))
102+
ParamValidate.gtEq[Double](0))
105103

106104
/** @group getParam */
107105
def getAlpha: Double = getOrDefault(alpha)

mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,7 @@ public void tearDown() {
2929
public void testParams() {
3030
JavaTestParams testParams = new JavaTestParams();
3131
Assert.assertEquals(testParams.getMyIntParam(), 1);
32+
testParams.setMyIntParam(2).setMyDoubleParam(0.4);
33+
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
3234
}
3335
}
Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package org.apache.spark.ml.param;
22

3-
import org.apache.spark.ml.param.Params;
4-
import org.apache.spark.ml.param.shared.HasMaxIter;
3+
import java.util.List;
4+
5+
import com.google.common.collect.Lists;
56

67
/**
78
* A subclass of Params for testing.
@@ -10,26 +11,35 @@ public class JavaTestParams extends JavaParams {
1011

1112
public IntParam myIntParam;
1213

13-
public DoubleParam myDoubleParam;
14-
1514
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
1615

17-
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
18-
1916
public JavaTestParams setMyIntParam(int value) {
2017
set(myIntParam, value); return this;
2118
}
2219

20+
public DoubleParam myDoubleParam;
21+
22+
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
23+
2324
public JavaTestParams setMyDoubleParam(double value) {
2425
set(myDoubleParam, value); return this;
2526
}
2627

28+
public Param<String> myStringParam;
29+
30+
public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
31+
32+
public JavaTestParams setMyStringParam(String value) {
33+
set(myStringParam, value); return this;
34+
}
35+
2736
public JavaTestParams() {
28-
myIntParam =
29-
new IntParam(this, "myIntParam", "this is an int param", ParamValidate.gt(0));
30-
myDoubleParam =
31-
new DoubleParam(this, "myDoubleParam", "this is a double param",
32-
ParamValidate.and(ParamValidate.gtEq(0.0), ParamValidate.gt(1.0));
33-
setDefault(myIntParam.w(1));
37+
myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidate.gt(0));
38+
myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
39+
ParamValidate.inRange(0.0, 1.0));
40+
List<String> validStrings = Lists.newArrayList("a", "b");
41+
myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
42+
ParamValidate.inArray(validStrings));
43+
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
3444
}
3545
}

0 commit comments

Comments
 (0)