Skip to content

Commit 58802e3

Browse files
committed
added train() to Predictor subclasses which does not take a ParamMap.
1 parent 57d54ab commit 58802e3

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
7070

7171
/**
7272
* Same as [[fit()]], but using strong types.
73-
*
74-
* @param dataset Training data. WARNING: This does not yet handle instance weights.
73+
* NOTE: This does NOT support instance weights.
74+
* @param dataset Training data. Instance weights are ignored.
7575
* @param paramMap Parameters for training.
7676
* These values override any specified in this Estimator's embedded ParamMap.
7777
*/
78-
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
78+
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
7979
val map = this.paramMap ++ paramMap
8080
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
8181
org.apache.spark.mllib.regression.LabeledPoint(label, features)
@@ -96,6 +96,13 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
9696
}
9797
lrm
9898
}
99+
100+
/**
101+
* Same as [[fit()]], but using strong types.
102+
* NOTE: This does NOT support instance weights.
103+
* @param dataset Training data. Instance weights are ignored.
104+
*/
105+
def train(dataset: RDD[LabeledPoint]): LogisticRegressionModel = train(dataset, new ParamMap())
99106
}
100107

101108

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
4747

4848
/**
4949
* Same as [[fit()]], but using strong types.
50-
*
51-
* @param dataset Training data. WARNING: This does not yet handle instance weights.
50+
* NOTE: This does NOT support instance weights.
51+
* @param dataset Training data. Instance weights are ignored.
5252
* @param paramMap Parameters for training.
5353
* These values override any specified in this Estimator's embedded ParamMap.
5454
*/
55-
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
55+
override def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LinearRegressionModel = {
5656
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
5757
org.apache.spark.mllib.regression.LabeledPoint(label, features)
5858
}
@@ -71,6 +71,13 @@ class LinearRegression extends Regressor[LinearRegression, LinearRegressionModel
7171
}
7272
lrm
7373
}
74+
75+
/**
76+
* Same as [[fit()]], but using strong types.
77+
* NOTE: This does NOT support instance weights.
78+
* @param dataset Training data. Instance weights are ignored.
79+
*/
80+
def train(dataset: RDD[LabeledPoint]): LinearRegressionModel = train(dataset, new ParamMap())
7481
}
7582

7683
/**

0 commit comments

Comments
 (0)