Skip to content

Commit b24c12d

Browse files
yanboliangjkbradley
authored andcommitted
[MINOR][ML] Rename weights to coefficients for examples/DeveloperApiExample
Rename ```weights``` to ```coefficients``` for examples/DeveloperApiExample. cc mengxr jkbradley Author: Yanbo Liang <[email protected]> Closes #10280 from yanboliang/spark-coefficients.
1 parent bc1ff9f commit b24c12d

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public static void main(String[] args) throws Exception {
8989
}
9090
if (sumPredictions != 0.0) {
9191
throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
92-
" even though all weights are 0!");
92+
" even though all coefficients are 0!");
9393
}
9494

9595
jsc.stop();
@@ -149,12 +149,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
149149
// Extract columns from data using helper method.
150150
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
151151

152-
// Do learning to estimate the weight vector.
152+
// Do learning to estimate the coefficients vector.
153153
int numFeatures = oldDataset.take(1).get(0).features().size();
154-
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
154+
Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here.
155155

156156
// Create a model, and return it.
157-
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
157+
return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this);
158158
}
159159

160160
@Override
@@ -173,12 +173,12 @@ public MyJavaLogisticRegression copy(ParamMap extra) {
173173
class MyJavaLogisticRegressionModel
174174
extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
175175

176-
private Vector weights_;
177-
public Vector weights() { return weights_; }
176+
private Vector coefficients_;
177+
public Vector coefficients() { return coefficients_; }
178178

179-
public MyJavaLogisticRegressionModel(String uid, Vector weights) {
179+
public MyJavaLogisticRegressionModel(String uid, Vector coefficients) {
180180
this.uid_ = uid;
181-
this.weights_ = weights;
181+
this.coefficients_ = coefficients;
182182
}
183183

184184
private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
@@ -208,7 +208,7 @@ public String uid() {
208208
* modifier.
209209
*/
210210
public Vector predictRaw(Vector features) {
211-
double margin = BLAS.dot(features, weights_);
211+
double margin = BLAS.dot(features, coefficients_);
212212
// There are 2 classes (binary classification), so we return a length-2 vector,
213213
// where index i corresponds to class i (i = 0, 1).
214214
return Vectors.dense(-margin, margin);
@@ -222,7 +222,7 @@ public Vector predictRaw(Vector features) {
222222
/**
223223
* Number of features the model was trained on.
224224
*/
225-
public int numFeatures() { return weights_.size(); }
225+
public int numFeatures() { return coefficients_.size(); }
226226

227227
/**
228228
* Create a copy of the model.
@@ -235,7 +235,7 @@ public Vector predictRaw(Vector features) {
235235
*/
236236
@Override
237237
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
238-
return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra)
238+
return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra)
239239
.setParent(parent());
240240
}
241241
}

examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ object DeveloperApiExample {
7575
prediction
7676
}.sum
7777
assert(sumPredictions == 0.0,
78-
"MyLogisticRegression predicted something other than 0, even though all weights are 0!")
78+
"MyLogisticRegression predicted something other than 0, even though all coefficients are 0!")
7979

8080
sc.stop()
8181
}
@@ -124,12 +124,12 @@ private class MyLogisticRegression(override val uid: String)
124124
// Extract columns from data using helper method.
125125
val oldDataset = extractLabeledPoints(dataset)
126126

127-
// Do learning to estimate the weight vector.
127+
// Do learning to estimate the coefficients vector.
128128
val numFeatures = oldDataset.take(1)(0).features.size
129-
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
129+
val coefficients = Vectors.zeros(numFeatures) // Learning would happen here.
130130

131131
// Create a model, and return it.
132-
new MyLogisticRegressionModel(uid, weights).setParent(this)
132+
new MyLogisticRegressionModel(uid, coefficients).setParent(this)
133133
}
134134

135135
override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
@@ -142,7 +142,7 @@ private class MyLogisticRegression(override val uid: String)
142142
*/
143143
private class MyLogisticRegressionModel(
144144
override val uid: String,
145-
val weights: Vector)
145+
val coefficients: Vector)
146146
extends ClassificationModel[Vector, MyLogisticRegressionModel]
147147
with MyLogisticRegressionParams {
148148

@@ -163,7 +163,7 @@ private class MyLogisticRegressionModel(
163163
* confidence for that label.
164164
*/
165165
override protected def predictRaw(features: Vector): Vector = {
166-
val margin = BLAS.dot(features, weights)
166+
val margin = BLAS.dot(features, coefficients)
167167
// There are 2 classes (binary classification), so we return a length-2 vector,
168168
// where index i corresponds to class i (i = 0, 1).
169169
Vectors.dense(-margin, margin)
@@ -173,7 +173,7 @@ private class MyLogisticRegressionModel(
173173
override val numClasses: Int = 2
174174

175175
/** Number of features the model was trained on. */
176-
override val numFeatures: Int = weights.size
176+
override val numFeatures: Int = coefficients.size
177177

178178
/**
179179
* Create a copy of the model.
@@ -182,7 +182,7 @@ private class MyLogisticRegressionModel(
182182
* This is used for the default implementation of [[transform()]].
183183
*/
184184
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
185-
copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent)
185+
copyValues(new MyLogisticRegressionModel(uid, coefficients), extra).setParent(parent)
186186
}
187187
}
188188
// scalastyle:on println

0 commit comments

Comments
 (0)