代码如下:
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.util.MLUtils

// Trains a linear model with SGD on `training`, then reports binary-classification metrics on
// `test`: area under ROC and PR, plus precision / recall / F-measure for each threshold.
//
// NOTE(review): `params`, `training` and `test` are defined outside this snippet. `test` is
// destructured as LabeledPoint below, so it is an RDD of LabeledPoint; `training` presumably is
// too — confirm. NONE / L1 / L2 start with an upper-case letter, so Scala matches them as stable
// identifiers. They must therefore be values in scope, presumably from a RegType enumeration —
// confirm.

// Pick the weight-update rule that matches the requested regularization.
val updater = params.regType match {
  case NONE => new SimpleUpdater()
  case L1   => new L1Updater()
  case L2   => new SquaredL2Updater()
}

// NOTE(review): this trains a *regression* model and then uses its raw output as a
// binary-classification score. BinaryClassificationMetrics only ranks by score, so this runs.
// However, SVMWithSGD / LogisticRegressionWithLBFGS (already imported above) are the usual
// models for a binary task — confirm which one is intended.
val algorithm = new LinearRegressionWithSGD()
algorithm.optimizer
  .setNumIterations(params.numIterations)
  .setStepSize(params.stepSize)
  .setUpdater(updater)
  .setRegParam(params.regParam)
val model = algorithm.run(training)

// Compute raw scores on the test set: each test point is scored exactly once, as a
// (score, label) pair.
val scoreAndLabels = test.map { case LabeledPoint(label, features) =>
  (model.predict(features), label)
}
// The original names are kept in case code outside this snippet refers to them. Both are lazy
// RDD views of scoreAndLabels and do no work unless an action is run on them.
val predictionAndLabel = scoreAndLabels
val prediction = scoreAndLabels.map(_._1)

// Get evaluation metrics.
// NOTE(review): this constructor does not down-sample thresholds, so every distinct score becomes
// a threshold. The per-threshold collects below may therefore return roughly one row per test
// point. Consider the (scoreAndLabels, numBins) overload on large test sets.
val metrics = new BinaryClassificationMetrics(scoreAndLabels)

val auROC = metrics.areaUnderROC()
val auPR = metrics.areaUnderPR()
println(s"Area under ROC = $auROC")
println(s"Area under precision-recall curve = $auPR")

// Label each per-threshold series; otherwise the three dumps are indistinguishable tuples.
metrics.precisionByThreshold().collect().foreach { case (t, p) =>
  println(s"Threshold: $t, Precision: $p")
}
metrics.recallByThreshold().collect().foreach { case (t, r) =>
  println(s"Threshold: $t, Recall: $r")
}
metrics.fMeasureByThreshold().collect().foreach { case (t, f) =>
  println(s"Threshold: $t, F-score: $f")
}