Skip to content

Commit 2d0e394

Browse files
committed
fix test problem
1 parent 53d7763 commit 2d0e394

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,15 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
9797
def getNeighborsCol: String = $(neighborsCol)
9898

9999
/**
100-
* Param for the name of the input column for neighbors in the adjacency list representation.
100+
* Param for the name of the input column for non-negative weights (similarities) of edges
101+
* between the vertex in `idCol` and each neighbor in `neighborsCol`.
101102
* Default: "similarities"
102103
* @group param
103104
*/
104105
@Since("2.4.0")
105106
val similaritiesCol = new Param[String](this, "similaritiesCol",
106-
"Name of the input column for neighbors in the adjacency list representation.",
107+
"Name of the input column for non-negative weights (similarities) of edges between the " +
108+
"vertex in `idCol` and each neighbor in `neighborsCol`.",
107109
(value: String) => value.nonEmpty)
108110

109111
setDefault(similaritiesCol, "similarities")

python/pyspark/ml/clustering.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,8 +1179,9 @@ class _PowerIterationClusteringParams(JavaParams, HasMaxIter, HasPredictionCol):
11791179
"representation.",
11801180
typeConverter=TypeConverters.toString)
11811181
similaritiesCol = Param(Params._dummy(), "similaritiesCol",
1182-
"non-negative weights (similarities) of edges between the vertex in " +
1183-
"`idCol` and each neighbor in `neighborsCol`",
1182+
"Name of the input column for non-negative weights (similarities) " +
1183+
"of edges between the vertex in `idCol` and each neighbor in " +
1184+
"`neighborsCol`",
11841185
typeConverter=TypeConverters.toString)
11851186

11861187
@since("2.4.0")
@@ -1253,8 +1254,8 @@ class PowerIterationClustering(JavaTransformer, _PowerIterationClusteringParams,
12531254
>>> schema = StructType([StructField("id", LongType(), False), \
12541255
StructField("neighbors", ArrayType(LongType(), False), True), \
12551256
StructField("similarities", ArrayType(DoubleType(), False), True)])
1256-
>>> pic = PowerIterationClustering()
12571257
>>> df = spark.createDataFrame(rdd, schema)
1258+
>>> pic = PowerIterationClustering()
12581259
>>> result = pic.setK(2).setMaxIter(40).transform(df)
12591260
>>> predictions = sorted(set([(i[0], i[1]) for i in result.select(result.id, result.prediction)
12601261
... .collect()]), key=lambda x: x[0])
@@ -1276,12 +1277,16 @@ class PowerIterationClustering(JavaTransformer, _PowerIterationClusteringParams,
12761277
>>> pic2.getMaxIter()
12771278
40
12781279
>>> pic3 = PowerIterationClustering(k=4, initMode="degree")
1280+
>>> pic3.getIdCol()
1281+
'id'
12791282
>>> pic3.getK()
12801283
4
12811284
>>> pic3.getMaxIter()
12821285
20
12831286
>>> pic3.getInitMode()
12841287
'degree'
1288+
1289+
12851290
.. versionadded:: 2.4.0
12861291
"""
12871292
@keyword_only
@@ -1294,7 +1299,8 @@ def __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random
12941299
super(PowerIterationClustering, self).__init__()
12951300
self._java_obj = self._new_java_obj(
12961301
"org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
1297-
self._setDefault(k=2, maxIter=20, initMode="random")
1302+
self._setDefault(k=2, maxIter=20, initMode="random", idCol="id", neighborsCol="neighbors",
1303+
similaritiesCol="similarities")
12981304
kwargs = self._input_kwargs
12991305
self.setParams(**kwargs)
13001306

0 commit comments

Comments
 (0)