第3例 使用Spark ML的逻辑回归来预测音乐标签
3.1 数据准备
3.1.1 数据集文件准备
-
(1) 该项目并为使用数据库当做数据源,而是直接将数据文件放在项目目录中, 这是一个结构化的简化数据集。 -
(2) 本项目使用的数据集是著名的 MNIST 数据集,该数据集包含 780 个特征。数据集地址: 百万歌曲数据集。
2.1.2 数据集字段解释
2.2 使用 Spark ML 实现代码
2.2.1 引入项目依赖
使用的依赖包多数来自于 Spark ML , 而非 Spark MLlib 。
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
2.2.2 将 MNIST 数据集以 libsvm 格式进行加载并解析
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")
2.2.3 准备训练和测试集
val splits = data.randomSplit(Array(0.75, 0.25), 12345L)
val training = splits(0).cache()
val test = splits(1)
2.2.4 运行训练算法来创建模型
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.setIntercept(true)
.setValidateData(true)
.run(training)
- 到这一步, 预测模型便已经创建成功, 后续只需要根据这个模型进行预测即可。
2.2.5 在测试上计算原始分数
val scoreAndLabels = test.map{
point => {
val score = model.predict(point.features)
(score, point.label)
}
}
- 到这一步,预测结果也几经的出来了,只需要循环遍历输出一下即可,预测结果如下图所示:
- 从上图中可以看出: 预测出来的
prediction 与 label 完全一致, 说明预测的准确率是很高的。 - 至此, 预测工作已经进行结束了, 剩下还有一些 观察训练过程 和 模型评估 的操作。
2.2.6 为模型评估初始化一个多类度量
val metrics = new MulticlassMetrics(scoreAndLabels)
2.2.7 构造混淆矩阵
println("Confusion matrix: ")
println(metrics.confusionMatrix)
混淆矩阵如下图所示:
2.2.8 总体统计信息
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
val labels = metrics.labels
labels.foreach(
l => println(s"Precision($l) = " + metrics.precision(l))
)
labels.foreach(
l => println(s"Recall($l) = " + metrics.recall(l))
)
labels.foreach(
l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
)
labels.foreach(
l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
)
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
上述代码的输出信息如下图所示:
2.2.9 项目完整代码
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
object SparkML_0105_test5 {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkML")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")
val splits = data.randomSplit(Array(0.75, 0.25), 12345L)
val training = splits(0).cache()
val test = splits(1)
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.setIntercept(true)
.setValidateData(true)
.run(training)
model.clearThreshold()
val scoreAndLabels = test.map{
point => {
val score = model.predict(point.features)
(score, point.label)
}
}
val metrics = new MulticlassMetrics(scoreAndLabels)
println("Confusion matrix: ")
println(metrics.confusionMatrix)
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
val labels = metrics.labels
labels.foreach(
l => println(s"Precision($l) = " + metrics.precision(l))
)
labels.foreach(
l => println(s"Recall($l) = " + metrics.recall(l))
)
labels.foreach(
l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
)
labels.foreach(
l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
)
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
spark.close()
}
}
|