Commit 53d7763

[SPARK-19826][ML][PYTHON] add spark.ml Python API for PIC

1 parent 1d758dc commit 53d7763
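For background (context only, not part of this patch): Power Iteration Clustering (Lin & Cohen, ICML 2010) embeds the vertices of a similarity graph via truncated power iteration on the degree-normalized affinity matrix, then groups the resulting low-dimensional embedding with k-means. Schematically:

% Core PIC update (Lin & Cohen, 2010): A is the pairwise affinity matrix,
% D its diagonal degree matrix, and v the evolving vertex embedding.
W = D^{-1} A, \qquad v^{(t+1)} = \frac{W\, v^{(t)}}{\lVert W\, v^{(t)} \rVert_1}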

File tree: 1 file changed (+191 -1 lines)

python/pyspark/ml/clustering.py (191 additions, 1 deletion)
@@ -19,7 +19,7 @@
 from pyspark import since, keyword_only
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, JavaWrapper
 from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc

@@ -1156,6 +1156,196 @@ def getKeepLastCheckpoint(self):
         return self.getOrDefault(self.keepLastCheckpoint)


+class _PowerIterationClusteringParams(JavaParams, HasMaxIter, HasPredictionCol):
+    """
+    Params for :py:class:`PowerIterationClustering`.
+
+    .. versionadded:: 2.4.0
+    """
+
+    k = Param(Params._dummy(), "k",
+              "The number of clusters to create. Must be > 1.",
+              typeConverter=TypeConverters.toInt)
+    initMode = Param(Params._dummy(), "initMode",
+                     "The initialization algorithm. This can be either " +
+                     "'random' to use a random vector as vertex properties, or 'degree' to use " +
+                     "a normalized sum of similarities with other vertices. Supported options: " +
+                     "'random' and 'degree'.",
+                     typeConverter=TypeConverters.toString)
+    idCol = Param(Params._dummy(), "idCol",
+                  "Name of the input column for vertex IDs.",
+                  typeConverter=TypeConverters.toString)
+    neighborsCol = Param(Params._dummy(), "neighborsCol",
+                         "Name of the input column for neighbors in the adjacency list " +
+                         "representation.",
+                         typeConverter=TypeConverters.toString)
+    similaritiesCol = Param(Params._dummy(), "similaritiesCol",
+                            "Name of the input column for the non-negative weights " +
+                            "(similarities) of edges between the vertex in `idCol` and " +
+                            "each neighbor in `neighborsCol`.",
+                            typeConverter=TypeConverters.toString)
+
+    @since("2.4.0")
+    def getK(self):
+        """
+        Gets the value of :py:attr:`k`.
+        """
+        return self.getOrDefault(self.k)
+
+    @since("2.4.0")
+    def getInitMode(self):
+        """
+        Gets the value of :py:attr:`initMode`.
+        """
+        return self.getOrDefault(self.initMode)
+
+    @since("2.4.0")
+    def getIdCol(self):
+        """
+        Gets the value of :py:attr:`idCol`.
+        """
+        return self.getOrDefault(self.idCol)
+
+    @since("2.4.0")
+    def getNeighborsCol(self):
+        """
+        Gets the value of :py:attr:`neighborsCol`.
+        """
+        return self.getOrDefault(self.neighborsCol)
+
+    @since("2.4.0")
+    def getSimilaritiesCol(self):
+        """
+        Gets the value of :py:attr:`similaritiesCol`.
+        """
+        return self.getOrDefault(self.similaritiesCol)
+
+
+@inherit_doc
+class PowerIterationClustering(JavaTransformer, _PowerIterationClusteringParams, JavaMLReadable,
+                               JavaMLWritable):
+    """
+    Power Iteration Clustering (PIC), a scalable graph clustering algorithm. PIC finds a very
+    low-dimensional embedding of a dataset using truncated power iteration on a normalized
+    pair-wise similarity matrix of the data, and clusters vertices from that embedding.
+
+    >>> from pyspark.sql.types import ArrayType, DoubleType, LongType, StructField, StructType
+    >>> import math
+    >>> def genCircle(r, n):
+    ...     points = []
+    ...     for i in range(0, n):
+    ...         theta = 2.0 * math.pi * i / n
+    ...         points.append((r * math.cos(theta), r * math.sin(theta)))
+    ...     return points
+    >>> def sim(x, y):
+    ...     dist = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1])
+    ...     return math.exp(-dist / 2.0)
+    >>> r1 = 1.0
+    >>> n1 = 10
+    >>> r2 = 4.0
+    >>> n2 = 40
+    >>> n = n1 + n2
+    >>> points = genCircle(r1, n1) + genCircle(r2, n2)
+    >>> similarities = []
+    >>> for i in range(1, n):
+    ...     neighbor = []
+    ...     weight = []
+    ...     for j in range(i):
+    ...         neighbor.append(long(j))
+    ...         weight.append(sim(points[i], points[j]))
+    ...     similarities.append([long(i), neighbor, weight])
+    >>> rdd = sc.parallelize(similarities, 2)
+    >>> schema = StructType([StructField("id", LongType(), False),
+    ...                      StructField("neighbors", ArrayType(LongType(), False), True),
+    ...                      StructField("similarities", ArrayType(DoubleType(), False), True)])
+    >>> pic = PowerIterationClustering()
+    >>> df = spark.createDataFrame(rdd, schema)
+    >>> result = pic.setK(2).setMaxIter(40).transform(df)
+    >>> predictions = sorted(set([(i[0], i[1]) for i in result.select(result.id, result.prediction)
+    ...                      .collect()]), key=lambda x: x[0])
+    >>> predictions[0]
+    (1, 1)
+    >>> predictions[8]
+    (9, 1)
+    >>> predictions[9]
+    (10, 0)
+    >>> predictions[20]
+    (21, 0)
+    >>> predictions[48]
+    (49, 0)
+    >>> pic_path = temp_path + "/pic"
+    >>> pic.save(pic_path)
+    >>> pic2 = PowerIterationClustering.load(pic_path)
+    >>> pic2.getK()
+    2
+    >>> pic2.getMaxIter()
+    40
+    >>> pic3 = PowerIterationClustering(k=4, initMode="degree")
+    >>> pic3.getK()
+    4
+    >>> pic3.getMaxIter()
+    20
+    >>> pic3.getInitMode()
+    'degree'
+
+    .. versionadded:: 2.4.0
+    """
+    @keyword_only
+    def __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
+                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
+        """
+        __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random", \
+                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")
+        """
+        super(PowerIterationClustering, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
+        self._setDefault(k=2, maxIter=20, initMode="random")
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.4.0")
+    def setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
+                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
+        """
+        setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random", \
+                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")
+        Sets params for PowerIterationClustering.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.4.0")
+    def setK(self, value):
+        """
+        Sets the value of :py:attr:`k`.
+        """
+        return self._set(k=value)
+
+    @since("2.4.0")
+    def setInitMode(self, value):
+        """
+        Sets the value of :py:attr:`initMode`.
+        """
+        return self._set(initMode=value)
+
+    @since("2.4.0")
+    def setIdCol(self, value):
+        """
+        Sets the value of :py:attr:`idCol`.
+        """
+        return self._set(idCol=value)
+
+    @since("2.4.0")
+    def setNeighborsCol(self, value):
+        """
+        Sets the value of :py:attr:`neighborsCol`.
+        """
+        return self._set(neighborsCol=value)
+
+    @since("2.4.0")
+    def setSimilaritiesCol(self, value):
+        """
+        Sets the value of :py:attr:`similaritiesCol`.
+        """
+        return self._set(similaritiesCol=value)
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.clustering
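A minimal end-to-end sketch of the API this commit adds (a hedged illustration, not part of the patch: it assumes a Spark build containing this change and an active SparkSession bound to `spark`; the tiny two-component graph is made up for demonstration):

from pyspark.ml.clustering import PowerIterationClustering

# One row per vertex, in the adjacency-list form the transformer expects;
# column names match the idCol/neighborsCol/similaritiesCol defaults.
data = [
    (0, [1], [1.0]),
    (1, [0, 2], [1.0, 1.0]),
    (2, [1], [1.0]),
    (3, [4], [1.0]),
    (4, [3], [1.0]),
]
df = spark.createDataFrame(data, ["id", "neighbors", "similarities"])

pic = PowerIterationClustering(k=2, maxIter=20, initMode="random")
result = pic.transform(df)  # appends a `prediction` column of cluster ids
result.select("id", "prediction").show()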
