@@ -20,6 +20,7 @@ package org.apache.spark.mllib.export.pmml
2020import org .dmg .pmml .RegressionModel
2121import org .scalatest .FunSuite
2222
23+ import org .apache .spark .mllib .classification .SVMModel
2324import org .apache .spark .mllib .export .ModelExportFactory
2425import org .apache .spark .mllib .export .ModelExportType
2526import org .apache .spark .mllib .regression .LassoModel
@@ -37,6 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
3738 val linearRegressionModel = new LinearRegressionModel (linearInput(0 ).features, linearInput(0 ).label);
3839 val ridgeRegressionModel = new RidgeRegressionModel (linearInput(0 ).features, linearInput(0 ).label);
3940 val lassoModel = new LassoModel (linearInput(0 ).features, linearInput(0 ).label);
41+ val svmModel = new SVMModel (linearInput(0 ).features, linearInput(0 ).label);
4042
4143 // act by exporting the model to the PMML format
4244 val linearModelExport = ModelExportFactory .createModelExport(linearRegressionModel, ModelExportType .PMML )
@@ -76,11 +78,25 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
7678 // it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
7779 assert(pmml.getModels().get(0 ).asInstanceOf [RegressionModel ]
7880 .getRegressionTables().get(0 ).getNumericPredictors().size() === lassoModel.weights.size)
81+
82+ // act
83+ val svmModelExport = ModelExportFactory .createModelExport(svmModel, ModelExportType .PMML )
84+ // assert that the PMML format is as expected
85+ assert(svmModelExport.isInstanceOf [PMMLModelExport ])
86+ pmml = svmModelExport.asInstanceOf [PMMLModelExport ].getPmml()
87+ assert(pmml.getHeader().getDescription() === " linear SVM" )
88+ // check that the number of fields match the weights size
89+ assert(pmml.getDataDictionary().getNumberOfFields() === svmModel.weights.size + 1 )
90+ // this verify that there is a model attached to the pmml object and the model is a regression one
91+ // it also verifies that the pmml model has a regression table with the same number of predictors of the model weights
92+ assert(pmml.getModels().get(0 ).asInstanceOf [RegressionModel ]
93+ .getRegressionTables().get(0 ).getNumericPredictors().size() === svmModel.weights.size)
7994
8095 // manual checking
8196 // ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
8297 // ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
8398 // ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
99+ // ModelExporter.toPMML(svmModel,"/tmp/svm.xml")
84100
85101 }
86102
0 commit comments