1818package org .apache .spark .mllib .export .pmml
1919
2020import org .apache .spark .mllib .clustering .KMeansModel
21+ import org .dmg .pmml .DataDictionary
22+ import org .dmg .pmml .FieldName
23+ import org .dmg .pmml .DataField
24+ import org .dmg .pmml .OpType
25+ import org .dmg .pmml .DataType
26+ import org .dmg .pmml .MiningSchema
27+ import org .dmg .pmml .MiningField
28+ import org .dmg .pmml .FieldUsageType
29+ import org .dmg .pmml .ComparisonMeasure
30+ import org .dmg .pmml .ComparisonMeasure .Kind
31+ import org .dmg .pmml .SquaredEuclidean
32+ import org .dmg .pmml .ClusteringModel
33+ import org .dmg .pmml .MiningFunctionType
34+ import org .dmg .pmml .ClusteringModel .ModelClass
35+ import org .dmg .pmml .ClusteringField
36+ import org .dmg .pmml .CompareFunctionType
37+ import org .dmg .pmml .Cluster
38+ import org .dmg .pmml .Array .Type
2139
2240/**
2341 * PMML Model Export for KMeansModel class
@@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
3048 populateKMeansPMML(model);
3149
3250 private def populateKMeansPMML (model : KMeansModel ): Unit = {
33- // TODO: set here header description
34- pmml.setVersion(" testing... kmeans..." );
35- // TODO: generate the model...
51+
52+ pmml.getHeader().setDescription(" k-means clustering" );
53+
54+ if (model.clusterCenters.length > 0 ){
55+
56+ val clusterCenter = model.clusterCenters(0 )
57+
58+ var fields = new Array [FieldName ](clusterCenter.size)
59+
60+ var dataDictionary = new DataDictionary ()
61+
62+ var miningSchema = new MiningSchema ()
63+
64+ for ( i <- 0 to (clusterCenter.size - 1 )) {
65+ fields(i) = FieldName .create(" field_" + i)
66+ dataDictionary.withDataFields(new DataField (fields(i), OpType .CONTINUOUS , DataType .DOUBLE ))
67+ miningSchema.withMiningFields(new MiningField (fields(i)).withUsageType(FieldUsageType .ACTIVE ))
68+ }
69+
70+ var comparisonMeasure = new ComparisonMeasure ()
71+ .withKind(Kind .DISTANCE )
72+ .withMeasure(new SquaredEuclidean ()
73+ );
74+
75+ dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size());
76+
77+ pmml.setDataDictionary(dataDictionary);
78+
79+ var clusteringModel = new ClusteringModel (miningSchema, comparisonMeasure, MiningFunctionType .CLUSTERING , ModelClass .CENTER_BASED , model.clusterCenters.length)
80+ .withModelName(" k-means" );
81+
82+ for ( i <- 0 to (clusterCenter.size - 1 )) {
83+ clusteringModel.withClusteringFields(new ClusteringField (fields(i)).withCompareFunction(CompareFunctionType .ABS_DIFF ))
84+ var cluster = new Cluster ().withName(" cluster_" + i).withArray(new org.dmg.pmml.Array ().withType(Type .REAL ).withN(clusterCenter.size).withValue(model.clusterCenters(i).toArray.mkString(" " )))
85+ // cluster.withSize(value) //we don't have the size of the single cluster but only the centroids (withValue)
86+ clusteringModel.withClusters(cluster)
87+ }
88+
89+ pmml.withModels(clusteringModel);
90+
91+ }
92+
3693 }
3794
3895}
0 commit comments