```scala
package mllib.tree

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Created by 汪本成 on 2016/7/18.
  */
object randomForest {

  // Silence unnecessary log output on the terminal
  // Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  // Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  var beg = System.currentTimeMillis()

  // Create the entry-point objects
  val conf = new SparkConf().setAppName("randomForest").setMaster("local")
  val sc = new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://192.168.43.150:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  // Convert each CSV line into a LabeledPoint
  val data = rawData.map { line =>
    val values = line.split(",").map(_.toDouble)
    // init returns every value except the last; the last column is the target
    val featureVector = Vectors.dense(values.init)
    // Decision trees require labels to start at 0, so subtract one
    val label = values.last - 1
    LabeledPoint(label, featureVector)
  }

  // Split into training (80%), cross-validation (10%) and test (10%) sets
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()
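  // Caching matters here: the tree learner makes multiple passes over the
  // training data, and without cache() each pass would re-read and re-parse
  // the file from HDFS.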

  // Build the random forest
  val numClass = 7 // number of classes
  val categoricalFeaturesInfo = Map[Int, Int](10 -> 4, 11 -> 40) // categorical (discrete) features: feature index -> number of categories
  val impurity = "entropy" // impurity measure used to compute information gain
  val number = 20 // number of trees to build
  val maxDepth = 4 // maximum depth of each tree
  val maxBins = 100 // maximum number of bins used when splitting a feature

  // Train the random-forest classification model
  val model = RandomForest.trainClassifier(trainData, numClass, categoricalFeaturesInfo, number, "auto", impurity, maxDepth, maxBins)
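  // "auto" is the featureSubsetStrategy: MLlib picks how many features to
  // consider at each split (sqrt(numFeatures) for classification when more
  // than one tree is trained).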

  val metrics = getMetrics(model, cvData)
  // Overall precision (fraction of samples classified correctly)
  val precision = metrics.precision
  // Per-class precision and recall
  val recall = (0 until 7).map( // DecisionTreeModel class labels start at 0
    cat => (metrics.precision(cat), metrics.recall(cat))
  )

  val end = System.currentTimeMillis()
  // Elapsed time in milliseconds
  var castTime = end - beg

  def main(args: Array[String]) {
println("========================================================================================")
//精确度(样本比例)
println("精确度: " + _precision_ )
println("========================================================================================")
//准确度(召回率)
println("准确度: ")
_recall_.foreach(println)
println("========================================================================================")
println(" 运行程序耗时: " + _castTime_ /1000 + "s")
}

  /**
    * Compute multiclass evaluation metrics for a RandomForestModel on a data set
    * @param model
    * @param data
    * @return
    */
  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(example => (model.predict(example.features), example.label))
    new MulticlassMetrics(predictionsAndLabels)
  }
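
  // MulticlassMetrics also exposes metrics.confusionMatrix, often the quickest
  // way to see which classes the forest confuses with each other.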

  /**
    * Predict classes according to the proportion in which each class appears in the data set
    * @param data
    * @return
    */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    // Count the samples in each class: (class, count)
    val countsByCategory = data.map(_.label).countByValue()
    // Sort by class and keep just the counts
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }
}
```
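
Note that classProbabilities is defined but never called from main. A minimal sketch of its intended use, as a baseline to judge the forest against (the variable names below are illustrative, not part of the original program): the probability that a guess drawn from the training-set class distribution matches a cross-validation label is the sum over classes of the product of the two class priors.

```scala
// Sketch: accuracy of random guessing weighted by training-set class priors.
val trainPriors = classProbabilities(trainData)
val cvPriors = classProbabilities(cvData)
val randomGuessAccuracy = trainPriors.zip(cvPriors).map {
  case (trainProb, cvProb) => trainProb * cvProb
}.sum
println("Random-guess baseline: " + randomGuessAccuracy)
```

Any model worth keeping should beat this baseline comfortably.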
Running the program produces output like the following:
16/07/18 23:30:11 INFO DAGScheduler: ResultStage 17 (collectAsMap at MulticlassMetrics.scala:54) finished in 0.003 s
16/07/18 23:30:11 INFO TaskSchedulerImpl: Removed TaskSet 17.0, whose tasks have all completed, from pool
16/07/18 23:30:11 INFO DAGScheduler: Job 9 finished: collectAsMap at MulticlassMetrics.scala:54, took 0.197033 s
Precision: 0.5307208847065288
Per-class (precision, recall):
(0.8087885985748219,0.03206818609907704)
(0.5233824352041768,0.9884494841004331)
(0.5730994152046783,0.6121521862578081)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
Elapsed time: 44s
16/07/18 23:30:12 INFO SparkContext: Invoking stop() from shutdown hook
16/07/18 23:30:12 INFO SparkUI: Stopped Spark web UI at http://192.168.43.1:4040
16/07/18 23:30:12 INFO MapOutputTrackerMasterEndpoint: MapOutputTrackerMasterEndpoint stopped!
16/07/18 23:30:12 INFO MemoryStore: MemoryStore cleared
16/07/18 23:30:12 INFO BlockManager: BlockManager stopped
16/07/18 23:30:12 INFO BlockManagerMaster: BlockManagerMaster stopped
16/07/18 23:30:12 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint: OutputCommitCoordinator stopped!
16/07/18 23:30:12 INFO … Deleting directory C:\Users\Administrator\…
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remote daemon shut down; proceeding with flushing remote transports.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remoting shut down.
Process finished with exit code 0
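
The forest above is deliberately small (20 trees of depth 4), and the output shows it never predicts the last four classes. A natural next step is to sweep the hyperparameters and score each candidate on cvData, reusing getMetrics from the listing. The grid values below are assumptions for illustration, not tuned results:

```scala
// Illustrative hyperparameter sweep; the candidate values are assumptions.
val evaluations =
  for (impurity <- Array("gini", "entropy");
       depth    <- Array(10, 20);
       bins     <- Array(40, 300))
  yield {
    val model = RandomForest.trainClassifier(
      trainData, 7, Map[Int, Int](10 -> 4, 11 -> 40),
      20, "auto", impurity, depth, bins)
    // Score on the cross-validation set, never on the training set
    ((impurity, depth, bins), getMetrics(model, cvData).precision)
  }
// Print the candidates from best to worst CV precision
evaluations.sortBy(_._2).reverse.foreach(println)
```

Whatever wins on cvData should be confirmed once against the held-out testData before being reported.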
