|
19 | 19 |
|
20 | 20 | from pyspark import since, keyword_only |
21 | 21 | from pyspark.ml.util import * |
22 | | -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper |
| 22 | +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, JavaWrapper |
23 | 23 | from pyspark.ml.param.shared import * |
24 | 24 | from pyspark.ml.common import inherit_doc |
25 | 25 |
|
@@ -1156,6 +1156,196 @@ def getKeepLastCheckpoint(self): |
1156 | 1156 | return self.getOrDefault(self.keepLastCheckpoint) |
1157 | 1157 |
|
1158 | 1158 |
|
class _PowerIterationClusteringParams(JavaParams, HasMaxIter, HasPredictionCol):
    """
    Params for :py:class:`PowerIterationClustering`.

    .. versionadded:: 2.4.0
    """

    k = Param(Params._dummy(), "k",
              "The number of clusters to create. Must be > 1.",
              typeConverter=TypeConverters.toInt)
    initMode = Param(Params._dummy(), "initMode",
                     "The initialization algorithm. This can be either " +
                     "'random' to use a random vector as vertex properties, or 'degree' to use " +
                     "a normalized sum of similarities with other vertices. Supported options: " +
                     "'random' and 'degree'.",
                     typeConverter=TypeConverters.toString)
    idCol = Param(Params._dummy(), "idCol",
                  "Name of the input column for vertex IDs.",
                  typeConverter=TypeConverters.toString)
    neighborsCol = Param(Params._dummy(), "neighborsCol",
                         "Name of the input column for neighbors in the adjacency list " +
                         "representation.",
                         typeConverter=TypeConverters.toString)
    similaritiesCol = Param(Params._dummy(), "similaritiesCol",
                            "non-negative weights (similarities) of edges between the vertex in " +
                            "`idCol` and each neighbor in `neighborsCol`",
                            typeConverter=TypeConverters.toString)

    @since("2.4.0")
    def getK(self):
        """
        Gets the value of :py:attr:`k`.
        """
        return self.getOrDefault(self.k)

    @since("2.4.0")
    def getInitMode(self):
        """
        Gets the value of :py:attr:`initMode`.
        """
        return self.getOrDefault(self.initMode)

    @since("2.4.0")
    def getIdCol(self):
        """
        Gets the value of :py:attr:`idCol`.
        """
        return self.getOrDefault(self.idCol)

    @since("2.4.0")
    def getNeighborsCol(self):
        """
        Gets the value of :py:attr:`neighborsCol`.
        """
        return self.getOrDefault(self.neighborsCol)

    @since("2.4.0")
    def getSimilaritiesCol(self):
        """
        Gets the value of :py:attr:`similaritiesCol`.
        """
        # Fixed: previously returned self.getOrDefault(self.binary); `binary` is
        # not a Param of this class, so the getter raised AttributeError.
        return self.getOrDefault(self.similaritiesCol)
| 1221 | + |
@inherit_doc
class PowerIterationClustering(JavaTransformer, _PowerIterationClusteringParams, JavaMLReadable,
                               JavaMLWritable):
    """
    Power Iteration Clustering (PIC), a scalable graph clustering algorithm. This
    transformer reads a graph given as an adjacency list (columns ``idCol``,
    ``neighborsCol`` and ``similaritiesCol``) and appends a cluster assignment in
    ``predictionCol``.

    >>> from pyspark.sql.types import ArrayType, DoubleType, LongType, StructField, StructType
    >>> import math
    >>> def genCircle(r, n):
    ...     points = []
    ...     for i in range(0, n):
    ...         theta = 2.0 * math.pi * i / n
    ...         points.append((r * math.cos(theta), r * math.sin(theta)))
    ...     return points
    >>> def sim(x, y):
    ...     dist = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1])
    ...     return math.exp(-dist / 2.0)
    >>> r1 = 1.0
    >>> n1 = 10
    >>> r2 = 4.0
    >>> n2 = 40
    >>> n = n1 + n2
    >>> points = genCircle(r1, n1) + genCircle(r2, n2)
    >>> similarities = []
    >>> for i in range(1, n):
    ...     neighbor = []
    ...     weight = []
    ...     for j in range(i):
    ...         neighbor.append(j)
    ...         weight.append(sim(points[i], points[j]))
    ...     similarities.append([i, neighbor, weight])
    >>> rdd = sc.parallelize(similarities, 2)
    >>> schema = StructType([StructField("id", LongType(), False), \
                 StructField("neighbors", ArrayType(LongType(), False), True), \
                 StructField("similarities", ArrayType(DoubleType(), False), True)])
    >>> pic = PowerIterationClustering()
    >>> df = spark.createDataFrame(rdd, schema)
    >>> result = pic.setK(2).setMaxIter(40).transform(df)
    >>> predictions = sorted(set([(i[0], i[1]) for i in result.select(result.id, result.prediction)
    ...     .collect()]), key=lambda x: x[0])
    >>> predictions[0]
    (1, 1)
    >>> predictions[8]
    (9, 1)
    >>> predictions[9]
    (10, 0)
    >>> predictions[20]
    (21, 0)
    >>> predictions[48]
    (49, 0)
    >>> pic_path = temp_path + "/pic"
    >>> pic.save(pic_path)
    >>> pic2 = PowerIterationClustering.load(pic_path)
    >>> pic2.getK()
    2
    >>> pic2.getMaxIter()
    40
    >>> pic3 = PowerIterationClustering(k=4, initMode="degree")
    >>> pic3.getK()
    4
    >>> pic3.getMaxIter()
    20
    >>> pic3.getInitMode()
    'degree'

    .. versionadded:: 2.4.0
    """

    @keyword_only
    def __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
        """
        __init__(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",\
                 idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")
        """
        super(PowerIterationClustering, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.clustering.PowerIterationClustering", self.uid)
        self._setDefault(k=2, maxIter=20, initMode="random")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    @since("2.4.0")
    def setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",
                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities"):
        """
        setParams(self, predictionCol="prediction", k=2, maxIter=20, initMode="random",\
                  idCol="id", neighborsCol="neighbors", similaritiesCol="similarities")
        Sets params for PowerIterationClustering.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.4.0")
    def setK(self, value):
        """
        Sets the value of :py:attr:`k`.
        """
        return self._set(k=value)

    @since("2.4.0")
    def setInitMode(self, value):
        """
        Sets the value of :py:attr:`initMode`.
        """
        return self._set(initMode=value)

    @since("2.4.0")
    def setIdCol(self, value):
        """
        Sets the value of :py:attr:`idCol`.
        """
        return self._set(idCol=value)

    @since("2.4.0")
    def setNeighborsCol(self, value):
        """
        Sets the value of :py:attr:`neighborsCol`.
        """
        return self._set(neighborsCol=value)

    @since("2.4.0")
    def setSimilaritiesCol(self, value):
        """
        Sets the value of :py:attr:`similaritiesCol`.
        """
        return self._set(similaritiesCol=value)
| 1348 | + |
1159 | 1349 | if __name__ == "__main__": |
1160 | 1350 | import doctest |
1161 | 1351 | import pyspark.ml.clustering |
|
0 commit comments