Skip to content

Commit da2ec11

Browse files
committed
[SPARK-1406] added linear SVM PMML export
1 parent 82f2131 commit da2ec11

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.export
1919

20+
import org.apache.spark.mllib.classification.SVMModel
2021
import org.apache.spark.mllib.clustering.KMeansModel
2122
import org.apache.spark.mllib.export.ModelExportType.ModelExportType
2223
import org.apache.spark.mllib.export.ModelExportType.PMML
@@ -44,6 +45,8 @@ private[mllib] object ModelExportFactory {
4445
new GeneralizedLinearPMMLModelExport(ridgeRegression, "ridge regression")
4546
case lassoRegression: LassoModel =>
4647
new GeneralizedLinearPMMLModelExport(lassoRegression, "lasso regression")
48+
case svm: SVMModel =>
49+
new GeneralizedLinearPMMLModelExport(svm, "linear SVM")
4750
case _ =>
4851
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
4952
}

mllib/src/test/scala/org/apache/spark/mllib/export/ModelExportFactorySuite.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.export
1919

2020
import org.scalatest.FunSuite
2121

22+
import org.apache.spark.mllib.classification.SVMModel
2223
import org.apache.spark.mllib.clustering.KMeansModel
2324
import org.apache.spark.mllib.linalg.Vectors
2425
import org.apache.spark.mllib.regression.LassoModel
@@ -48,15 +49,16 @@ class ModelExportFactorySuite extends FunSuite{
4849

4950
}
5051

51-
test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a"
52-
+"LinearRegressionModel, RidgeRegressionModel or LassoModel") {
52+
test("ModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
53+
+"LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
5354

5455
//arrange
5556
val linearInput = LinearDataGenerator.generateLinearInput(
5657
3.0, Array(10.0, 10.0), 1, 17)
5758
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
5859
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
5960
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
61+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);
6062

6163
//act
6264
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
@@ -73,6 +75,11 @@ class ModelExportFactorySuite extends FunSuite{
7375
//assert
7476
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
7577

78+
//act
79+
val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML)
80+
//assert
81+
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
82+
7683
}
7784

7885
test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {

mllib/src/test/scala/org/apache/spark/mllib/export/pmml/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.export.pmml
2020
import org.dmg.pmml.RegressionModel
2121
import org.scalatest.FunSuite
2222

23+
import org.apache.spark.mllib.classification.SVMModel
2324
import org.apache.spark.mllib.export.ModelExportFactory
2425
import org.apache.spark.mllib.export.ModelExportType
2526
import org.apache.spark.mllib.regression.LassoModel
@@ -37,6 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
3738
val linearRegressionModel = new LinearRegressionModel(linearInput(0).features, linearInput(0).label);
3839
val ridgeRegressionModel = new RidgeRegressionModel(linearInput(0).features, linearInput(0).label);
3940
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label);
41+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);
4042

4143
//act by exporting the model to the PMML format
4244
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
@@ -76,11 +78,25 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
7678
//it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
7779
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
7880
.getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size)
81+
82+
//act
83+
val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML)
84+
//assert that the PMML format is as expected
85+
assert(svmModelExport.isInstanceOf[PMMLModelExport])
86+
pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml()
87+
assert(pmml.getHeader().getDescription() === "linear SVM")
88+
//check that the number of fields match the weights size
89+
assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1)
90+
//this verify that there is a model attached to the pmml object and the model is a regression one
91+
//it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
92+
assert(pmml.getModels().get(0).asInstanceOf[RegressionModel]
93+
.getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size)
7994

8095
//manual checking
8196
//ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
8297
//ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
8398
//ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
99+
//ModelExporter.toPMML(svmModel,"/tmp/svm.xml")
84100

85101
}
86102

0 commit comments

Comments
 (0)