diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala deleted file mode 100644 index 622b53a252ac5..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.pmml.export - -import scala.{Array => SArray} - -import org.dmg.pmml._ - -import org.apache.spark.mllib.regression.GeneralizedLinearModel - -/** - * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel - */ -private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, - description : String, - normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) - extends PMMLModelExport { - - populateBinaryClassificationPMML() - - /** - * Export the input LogisticRegressionModel or SVMModel to PMML format. 
- */ - private def populateBinaryClassificationPMML(): Unit = { - pmml.getHeader.setDescription(description) - - if (model.weights.size > 0) { - val fields = new SArray[FieldName](model.weights.size) - val dataDictionary = new DataDictionary - val miningSchema = new MiningSchema - val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") - var interceptNO = threshold - if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { - if (threshold <= 0) { - interceptNO = Double.MinValue - } else if (threshold >= 1) { - interceptNO = Double.MaxValue - } else { - interceptNO = -math.log(1 / threshold - 1) - } - } - val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") - val regressionModel = new RegressionModel() - .withFunctionName(MiningFunctionType.CLASSIFICATION) - .withMiningSchema(miningSchema) - .withModelName(description) - .withNormalizationMethod(normalizationMethod) - .withRegressionTables(regressionTableYES, regressionTableNO) - - for (i <- 0 until model.weights.size) { - fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema - .withMiningFields(new MiningField(fields(i)) - .withUsageType(FieldUsageType.ACTIVE)) - regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) - } - - // add target field - val targetField = FieldName.create("target") - dataDictionary - .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) - miningSchema - .withMiningFields(new MiningField(targetField) - .withUsageType(FieldUsageType.TARGET)) - - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - - pmml.setDataDictionary(dataDictionary) - pmml.withModels(regressionModel) - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExport.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExport.scala new file mode 100644 index 0000000000000..dcc1a948210db --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExport.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel class with ClassificationModel + */ +private[mllib] class ClassificationPMMLModelExport( + model : GeneralizedLinearModel, + numClasses: Int, + numFeatures: Int, + description : String, + normalizationMethod : RegressionNormalizationMethodType, + threshold: Double) + extends PMMLModelExport { + + populateClassificationPMML() + + /** + * Export the input LogisticRegressionModel or SVMModel to PMML format. 
+ */ + private def populateClassificationPMML(): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + + val fields = new SArray[FieldName](numFeatures) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + + for (i <- 0 until numFeatures) { + fields(i) = FieldName.create("field_" + i) + dataDictionary + .withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + } + + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.CLASSIFICATION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withNormalizationMethod(normalizationMethod) + + var interceptCategoryZero = threshold + if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { + if (threshold <= 0) { + interceptCategoryZero = Double.MinValue + } else if (threshold >= 1) { + interceptCategoryZero = Double.MaxValue + } else { + interceptCategoryZero = -math.log(1 / threshold - 1) + } + } + val regressionTableCategoryZero = new RegressionTable(interceptCategoryZero) + .withTargetCategory("0") + regressionModel.withRegressionTables(regressionTableCategoryZero) + + // binary classification: add the category 1 table + if (numClasses == 2) { + // intercept is stored in model.intercept + val regressionTableCategoryOne = new RegressionTable(model.intercept) + .withTargetCategory("1") + for (i <- 0 until numFeatures) { + regressionTableCategoryOne + .withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + regressionModel.withRegressionTables(regressionTableCategoryOne) + } else { + // multiclass classification: add one table per remaining category + for (i <- 0 until numClasses - 1) { + if (model.weights.size == (numClasses - 1) * (numFeatures + 1)) { + // intercept is the last element of each class's block in weights + val regressionTableCategory = new RegressionTable( + model.weights(i * (numFeatures + 1) + numFeatures)) + 
.withTargetCategory((i + 1).toString) + for (j <- 0 until numFeatures) { + regressionTableCategory.withNumericPredictors(new NumericPredictor(fields(j), + model.weights(i * (numFeatures + 1) + j))) + } + regressionModel.withRegressionTables(regressionTableCategory) + } else { + // intercept is zero + val regressionTableCategory = new RegressionTable(0) + .withTargetCategory((i + 1).toString) + for (j <- 0 until numFeatures) { + regressionTableCategory.withNumericPredictors(new NumericPredictor(fields(j), + model.weights(i*numFeatures + j))) + } + regressionModel.withRegressionTables(regressionTableCategory) + } + } + } + + // add target field + val targetField = FieldName.create("target") + dataDictionary + .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 29bd689e1185a..7f1376a3f2d02 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -18,10 +18,10 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionNormalizationMethodType - import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel import 
org.apache.spark.mllib.regression.RidgeRegressionModel @@ -43,18 +43,16 @@ private[mllib] object PMMLModelExportFactory { case lasso: LassoModel => new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => - new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + new ClassificationPMMLModelExport( + svm, 2, svm.weights.size, + "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => - if (logistic.numClasses == 2) { - new BinaryClassificationPMMLModelExport( - logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, - logistic.getThreshold.getOrElse(0.5)) - } else { - throw new IllegalArgumentException( - "PMML Export not supported for Multinomial Logistic Regression") - } + new ClassificationPMMLModelExport( + logistic, + logistic.numClasses, logistic.numFeatures, + "logistic regression", RegressionNormalizationMethodType.LOGIT, + logistic.getThreshold.getOrElse(0.5)) case _ => throw new IllegalArgumentException( "PMML Export not supported for model: " + model.getClass.getName) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala deleted file mode 100644 index 4c6e76e47419b..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.pmml.export - -import org.dmg.pmml.RegressionModel -import org.dmg.pmml.RegressionNormalizationMethodType - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.classification.SVMModel -import org.apache.spark.mllib.util.LinearDataGenerator - -class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { - - test("logistic regression PMML export") { - val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - val logisticRegressionModel = - new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) - - val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) - - // assert that the PMML format is as expected - assert(logisticModelExport.isInstanceOf[PMMLModelExport]) - val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml - assert(pmml.getHeader.getDescription === "logistic regression") - // check that the number of fields match the weights size - assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) - // This verify that there is a model attached to the pmml object and the model is a regression - // one. It also verifies that the pmml model has a regression table (for target category 1) - // with the same number of predictors of the model weights. 
- val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] - assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") - assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size - === logisticRegressionModel.weights.size) - // verify if there is a second table with target category 0 and no predictors - assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") - assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) - // ensure logistic regression has normalization method set to LOGIT - assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) - } - - test("linear SVM PMML export") { - val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - - val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - - // assert that the PMML format is as expected - assert(svmModelExport.isInstanceOf[PMMLModelExport]) - val pmml = svmModelExport.getPmml - assert(pmml.getHeader.getDescription - === "linear SVM") - // check that the number of fields match the weights size - assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) - // This verify that there is a model attached to the pmml object and the model is a regression - // one. It also verifies that the pmml model has a regression table (for target category 1) - // with the same number of predictors of the model weights. 
- val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] - assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") - assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size - === svmModel.weights.size) - // verify if there is a second table with target category 0 and no predictors - assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") - assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) - // ensure linear SVM has normalization method set to NONE - assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) - } - -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..b17262ccb5f6a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/ClassificationPMMLModelExportSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.LinearDataGenerator + +class ClassificationPMMLModelExportSuite extends SparkFunSuite { + + test("binary logistic regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of fields is the weights size plus one (the target field) + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) + // Verify that a model is attached to the pmml object and that it is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with as many predictors as the model has weights. 
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size + === logisticRegressionModel.weights.size) + // verify that the first table (index 0) has target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + } + + test("linear SVM PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + + // assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === "linear SVM") + // check that the number of fields is the weights size plus one (the target field) + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // Verify that a model is attached to the pmml object and that it is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with as many predictors as the model has weights. 
+ val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size + === svmModel.weights.size) + // verify that the first table (index 0) has target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size === 0) + // ensure linear SVM has normalization method set to NONE + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) + } + + test("multiclass logistic regression PMML export (without intercept)") { + /** 3 classes, 2 features */ + val logisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 0.0, + numFeatures = 2, numClasses = 3) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of fields is numFeatures plus one (the target field) + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.numFeatures + 1) + // Verify that a model is attached to the pmml object and that it is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with weights.size / numFeatures predictors. 
+ var pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size + === logisticRegressionModel.weights.size / logisticRegressionModel.numFeatures) + // verify there is a category 2 as there are 3 classes + pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(2).getTargetCategory === "2") + assert(pmmlRegressionModel.getRegressionTables.get(2).getNumericPredictors.size + === logisticRegressionModel.weights.size / logisticRegressionModel.numFeatures) + // verify that the first table (index 0) has target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + // ensure the category 1 and 2 tables have intercept 0 + assert(pmmlRegressionModel.getRegressionTables.get(1).getIntercept() === 0) + assert(pmmlRegressionModel.getRegressionTables.get(2).getIntercept() === 0) + } + + test("multiclass logistic regression PMML export (with intercept)") { + /** 3 classes, 2 features */ + val logisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.01, 0.3, 0.4, 0.02), intercept = 0.0, + numFeatures = 2, numClasses = 3) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of 
fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.numFeatures + 1) + // Verify that a model is attached to the pmml object and that it is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with weights.size / (numFeatures + 1) predictors. + var pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size + === logisticRegressionModel.weights.size / (logisticRegressionModel.numFeatures + 1)) + // verify there is a category 2 as there are 3 classes + pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(2).getTargetCategory === "2") + assert(pmmlRegressionModel.getRegressionTables.get(2).getNumericPredictors.size + === logisticRegressionModel.weights.size / (logisticRegressionModel.numFeatures + 1)) + // verify that the first table (index 0) has target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + // ensure the category 1 and 2 tables have intercepts 0.01 and 0.02 + assert(pmmlRegressionModel.getRegressionTables.get(1).getIntercept() === 0.01) + assert(pmmlRegressionModel.getRegressionTables.get(2).getIntercept() === 0.02) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index af49450961750..a61b81732ad91 100644 
--- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -57,31 +57,30 @@ class PMMLModelExportFactorySuite extends SparkFunSuite { assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) } - test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " - + "when passing a LogisticRegressionModel or SVMModel") { + test("PMMLModelExportFactory create ClassificationPMMLModelExport " + + "when passing a Binary LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) - assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + assert(logisticRegressionModelExport.isInstanceOf[ClassificationPMMLModelExport]) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + assert(svmModelExport.isInstanceOf[ClassificationPMMLModelExport]) } - test("PMMLModelExportFactory throw IllegalArgumentException " - + "when passing a Multinomial Logistic Regression") { + test("PMMLModelExportFactory create ClassificationPMMLModelExport " + + "when passing a Multiclass Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, - numFeatures = 2, numClasses = 3) - - intercept[IllegalArgumentException] { - PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) - } + weights = Vectors.dense(0.1, 0.2, 0.3, 
0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + val multiclassLogisticRegressionModelExport = PMMLModelExportFactory + .createPMMLModelExport(multiclassLogisticRegressionModel) + assert(multiclassLogisticRegressionModelExport.isInstanceOf[ClassificationPMMLModelExport]) } test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 54a9ad956d119..c1fe0b33b4e8e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -76,7 +76,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup"), + // SPARK-11401 Superseded by generic ClassificationPMMLModelExport + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.pmml.export.BinaryClassificationPMMLModelExport") ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.regression.LeastSquaresAggregator.add"),