@@ -77,7 +77,7 @@ public static void main(String[] args) {
ParamMap paramMap = new ParamMap();
paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
- double[] thresholds = {0.45, 0.55};
+ double[] thresholds = {0.5, 0.5};
paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params.

// One can also combine ParamMaps.
examples/src/main/python/ml/simple_params_example.py (24 changes: 11 additions & 13 deletions)
@@ -20,11 +20,10 @@
import pprint
import sys

- from pyspark import SparkContext
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import DenseVector
from pyspark.mllib.regression import LabeledPoint
- from pyspark.sql import SQLContext
+ from pyspark.sql import SparkSession

"""
A simple example demonstrating ways to specify parameters for Estimators and Transformers.
@@ -33,21 +32,20 @@
"""

if __name__ == "__main__":
if len(sys.argv) > 1:
print("Usage: simple_params_example", file=sys.stderr)
exit(1)
- sc = SparkContext(appName="PythonSimpleParamsExample")
- sqlContext = SQLContext(sc)
+ spark = SparkSession \
+     .builder \
+     .appName("SimpleTextClassificationPipeline") \
+     .getOrCreate()

# prepare training data.
# We create an RDD of LabeledPoints and convert them into a DataFrame.
# A LabeledPoint is an Object with two fields named label and features
# and Spark SQL identifies these fields and creates the schema appropriately.
- training = sc.parallelize([
+ training = spark.createDataFrame([
LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])),
- LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF()
+ LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))])

# Create a LogisticRegression instance with maxIter = 10.
# This instance is an Estimator.
@@ -70,18 +68,18 @@

# We may alternatively specify parameters using a parameter map.
# paramMap overrides all lr parameters set earlier.
- paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}
+ paramMap = {lr.maxIter: 20, lr.thresholds: [0.5, 0.5], lr.probabilityCol: "myProbability"}
@yanboliang (Contributor) commented on May 16, 2016:

Oh, it throws an exception when we make predictions because we want to find an authoritative threshold, so this change is okay. Actually, we use threshold more frequently than thresholds in LogisticRegression, because LR does not currently support multiclass classification. The community is trying to find a way to harmonize the two params for LR, but has not settled on a final solution. See SPARK-11834 and SPARK-11543.
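
For readers unfamiliar with the threshold/thresholds interplay: per the documented binary LogisticRegression behavior, a two-element thresholds vector [t0, t1] is equivalent to the scalar threshold 1 / (1 + t0 / t1), and the two params must agree. A minimal sketch in plain Python (illustrative only, not part of the patch):

```python
# Documented relationship between LogisticRegression's scalar `threshold`
# and a two-element `thresholds` vector in binary classification:
#   threshold = 1 / (1 + thresholds[0] / thresholds[1])
def equivalent_threshold(t0, t1):
    return 1.0 / (1.0 + t0 / t1)

assert equivalent_threshold(0.5, 0.5) == 0.5                  # matches the default threshold of 0.5
assert abs(equivalent_threshold(0.45, 0.55) - 0.55) < 1e-12   # implies 0.55, which conflicts with a default threshold
```

This is why [0.5, 0.5] is the safe choice here: it reproduces the default threshold exactly, while [0.45, 0.55] implies a threshold of 0.55.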


# Now learn a new model using the new parameters.
model2 = lr.fit(training, paramMap)
print("Model 2 was fit using parameters:\n")
pprint.pprint(model2.extractParamMap())

# prepare test data.
- test = sc.parallelize([
+ test = spark.createDataFrame([
LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])),
LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])),
- LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF()
+ LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))])

# Make predictions on test data using the Transformer.transform() method.
# LogisticRegressionModel.transform will only use the 'features' column.
@@ -95,4 +93,4 @@
print("features=%s,label=%s -> prob=%s, prediction=%s"
% (row.features, row.label, row.myProbability, row.prediction))

- sc.stop()
+ spark.stop()
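
Taken together, the Python changes follow the standard Spark 2.0 migration from SparkContext/SQLContext to SparkSession. A minimal self-contained sketch of that pattern (hypothetical app name and data, not part of the patch):

```python
from pyspark.sql import SparkSession

# SparkSession.builder replaces the SparkContext/SQLContext pair.
spark = SparkSession.builder.appName("MigrationSketch").getOrCreate()

# spark.createDataFrame(...) replaces sc.parallelize(...).toDF().
df = spark.createDataFrame([(1.0, 0.1), (0.0, -0.5)], ["label", "feature"])
df.show()

spark.stop()  # replaces sc.stop()
```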
@@ -70,7 +70,7 @@ object SimpleParamsExample {
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
- paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params.
+ paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.5, 0.5)) // Specify multiple Params.

// One can also combine ParamMaps.
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name