本文通过一个示例介绍使用 Scala(Spark ML)构建机器学习模型的主要步骤。这里只列出了基本环节,可在此基础上进行扩充。
/**
 * End-to-end Spark ML example:
 * load JSON data, build features with RFormula, tune a LogisticRegression inside a
 * Pipeline with TrainValidationSplit, evaluate on a held-out test set, inspect the best
 * model, then save the fitted TrainValidationSplitModel and reload it.
 */
object mlExample {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}

    val spark = SparkSession.builder()
      .appName("TobyGao")
      .enableHiveSupport()
      .getOrCreate()

    val modelPath = "/user/Tobygao/model_saved"
    val dataPath = "/user/Tobygao/ml_data"
    val savedModelLocation = s"$modelPath/tmp/modelLocation"
    // Fixed seed so the train/test split, and therefore the reported metrics, is reproducible.
    val splitSeed = 42L

    // Always release the Spark session, even if a step below throws.
    try {
      // 1 - load data
      val df = spark.read.json(s"$dataPath/data/simple-ml")

      // 2 - train/test split
      val Array(train, test) = df.randomSplit(Array(0.7, 0.3), splitSeed)

      // 3 - feature vector (VectorAssembler or RFormula). The formula itself is supplied
      //     by the parameter grid below. RFormula emits "features" and an indexed "label".
      val rForm = new RFormula()

      // 4 - model definition
      val lr = new LogisticRegression()
        .setLabelCol("label")
        .setFeaturesCol("features")
      println(lr.explainParams())

      // 5 - pipeline: feature building followed by the classifier
      val stages = Array(rForm, lr)
      val pipeline = new Pipeline().setStages(stages)

      // 6 - parameter grid: candidate formulas x regularisation settings
      val params = new ParamGridBuilder()
        .addGrid(rForm.formula, Array(
          "lab ~ . + color:value1",
          "lab ~ . + color:value1 + color:value2"))
        .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
        .addGrid(lr.regParam, Array(0.1, 2.0))
        .build()

      // 7 - evaluator: ROC AUC must be computed from the continuous raw score,
      //     not from the thresholded 0/1 "prediction" column.
      val evaluator = new BinaryClassificationEvaluator()
        .setMetricName("areaUnderROC")
        .setRawPredictionCol("rawPrediction")
        .setLabelCol("label")

      // 8 - TrainValidationSplit
      val tvs = new TrainValidationSplit()
        .setTrainRatio(0.75)             // train / validation split ratio
        .setEstimatorParamMaps(params)   // parameter grid
        .setEstimator(pipeline)          // estimator
        .setEvaluator(evaluator)         // evaluator

      // 9 - fit
      val tvsFitted = tvs.fit(train)

      // 10 - predict on the held-out test set
      val tvsPredict = tvsFitted.transform(test)
      tvsPredict.show()

      // 11 - evaluate
      println(s"AUC ${evaluator.evaluate(tvsPredict)}")

      // 12 - inspect the best model (stage 1 of the pipeline is the logistic regression)
      val trainedPipeline = tvsFitted.bestModel.asInstanceOf[PipelineModel]
      val trainedLR = trainedPipeline.stages(1).asInstanceOf[LogisticRegressionModel]
      if (trainedLR.hasSummary) {
        // objectiveHistory holds the objective value after each training iteration.
        // Use it to judge convergence: raise maxIter, stop early, or retune parameters.
        println(trainedLR.summary.objectiveHistory.mkString(","))
      }

      // 13 - save
      tvsFitted.write.overwrite().save(savedModelLocation)

      // 14 - load and actually use the reloaded model (transform is lazy; show() triggers it)
      val model = TrainValidationSplitModel.load(savedModelLocation)
      model.transform(test).show()
    } finally {
      spark.stop()
    }
  }
}
结果:
1-load data
+-----+----+------+------------------+
|color| lab|value1| value2|
+-----+----+------+------------------+
|green|good| 1|14.386294994851129|
|green| bad| 16|14.386294994851129|
| blue| bad| 8|14.386294994851129|
| blue| bad| 8|14.386294994851129|
| blue| bad| 12|14.386294994851129|
|green| bad| 16|14.386294994851129|
|green|good| 12|14.386294994851129|
| red|good| 35|14.386294994851129|
| red|good| 35|14.386294994851129|
| red| bad| 2|14.386294994851129|
| red| bad| 16|14.386294994851129|
| red| bad| 16|14.386294994851129|
| blue| bad| 8|14.386294994851129|
|green|good| 1|14.386294994851129|
|green|good| 12|14.386294994851129|
| blue| bad| 8|14.386294994851129|
| red|good| 35|14.386294994851129|
| blue| bad| 12|14.386294994851129|
| red| bad| 16|14.386294994851129|
|green|good| 12|14.386294994851129|
+-----+----+------+------------------+
3-RFormula
+-----+----+------+------------------+--------------------+-----+
|color| lab|value1| value2| features|label|
+-----+----+------+------------------+--------------------+-----+
|green|good| 1|14.386294994851129|(10,[1,2,3,5,8],[...| 1.0|
| blue| bad| 8|14.386294994851129|(10,[2,3,6,9],[8....| 0.0|
| blue| bad| 12|14.386294994851129|(10,[2,3,6,9],[12...| 0.0|
|green|good| 15| 38.97187133755819|(10,[1,2,3,5,8],[...| 1.0|
|green|good| 12|14.386294994851129|(10,[1,2,3,5,8],[...| 1.0|
|green| bad| 16|14.386294994851129|(10,[1,2,3,5,8],[...| 0.0|
| red|good| 35|14.386294994851129|(10,[0,2,3,4,7],[...| 1.0|
| red| bad| 1| 38.97187133755819|(10,[0,2,3,4,7],[...| 0.0|
| red| bad| 2|14.386294994851129|(10,[0,2,3,4,7],[...| 0.0|
| red| bad| 16|14.386294994851129|(10,[0,2,3,4,7],[...| 0.0|
| red|good| 45| 38.97187133755819|(10,[0,2,3,4,7],[...| 1.0|
|green|good| 1|14.386294994851129|(10,[1,2,3,5,8],[...| 1.0|
| blue| bad| 8|14.386294994851129|(10,[2,3,6,9],[8....| 0.0|
| blue| bad| 12|14.386294994851129|(10,[2,3,6,9],[12...| 0.0|
|green|good| 15| 38.97187133755819|(10,[1,2,3,5,8],[...| 1.0|
|green|good| 12|14.386294994851129|(10,[1,2,3,5,8],[...| 1.0|
|green| bad| 16|14.386294994851129|(10,[1,2,3,5,8],[...| 0.0|
| red|good| 35|14.386294994851129|(10,[0,2,3,4,7],[...| 1.0|
| red| bad| 1| 38.97187133755819|(10,[0,2,3,4,7],[...| 0.0|
| red| bad| 2|14.386294994851129|(10,[0,2,3,4,7],[...| 0.0|
+-----+----+------+------------------+--------------------+-----+
10- model prediction
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
|color| lab|value1| value2| features|label| rawPrediction| probability|prediction|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
| blue| bad| 8|14.386294994851129|(7,[2,3,6],[8.0,1...| 0.0|[1.81841935188104...|[0.86037635368405...| 0.0|
| blue| bad| 8|14.386294994851129|(7,[2,3,6],[8.0,1...| 0.0|[1.81841935188104...|[0.86037635368405...| 0.0|
| blue| bad| 8|14.386294994851129|(7,[2,3,6],[8.0,1...| 0.0|[1.81841935188104...|[0.86037635368405...| 0.0|
| blue| bad| 8|14.386294994851129|(7,[2,3,6],[8.0,1...| 0.0|[1.81841935188104...|[0.86037635368405...| 0.0|
| blue| bad| 8|14.386294994851129|(7,[2,3,6],[8.0,1...| 0.0|[1.81841935188104...|[0.86037635368405...| 0.0|
| blue| bad| 12|14.386294994851129|(7,[2,3,6],[12.0,...| 0.0|[2.15923553226233...|[0.89652865416576...| 0.0|
| blue| bad| 12|14.386294994851129|(7,[2,3,6],[12.0,...| 0.0|[2.15923553226233...|[0.89652865416576...| 0.0|
| blue| bad| 12|14.386294994851129|(7,[2,3,6],[12.0,...| 0.0|[2.15923553226233...|[0.89652865416576...| 0.0|
|green| bad| 16|14.386294994851129|[0.0,1.0,16.0,14....| 0.0|[-0.6607070390540...|[0.34058080292169...| 1.0|
|green| bad| 16|14.386294994851129|[0.0,1.0,16.0,14....| 0.0|[-0.6607070390540...|[0.34058080292169...| 1.0|
|green| bad| 16|14.386294994851129|[0.0,1.0,16.0,14....| 0.0|[-0.6607070390540...|[0.34058080292169...| 1.0|
|green|good| 1|14.386294994851129|[0.0,1.0,1.0,14.3...| 1.0|[-0.4860199728364...|[0.38083160751668...| 1.0|
|green|good| 1|14.386294994851129|[0.0,1.0,1.0,14.3...| 1.0|[-0.4860199728364...|[0.38083160751668...| 1.0|
|green|good| 12|14.386294994851129|[0.0,1.0,12.0,14....| 1.0|[-0.6141238213959...|[0.35111907339348...| 1.0|
|green|good| 12|14.386294994851129|[0.0,1.0,12.0,14....| 1.0|[-0.6141238213959...|[0.35111907339348...| 1.0|
|green|good| 15| 38.97187133755819|[0.0,1.0,15.0,38....| 1.0|[-1.1954765118448...|[0.23228089715736...| 1.0|
|green|good| 15| 38.97187133755819|[0.0,1.0,15.0,38....| 1.0|[-1.1954765118448...|[0.23228089715736...| 1.0|
| red| bad| 1| 38.97187133755819|[1.0,0.0,1.0,38.9...| 0.0|[1.34210087888720...|[0.79283521791024...| 0.0|
| red| bad| 1| 38.97187133755819|[1.0,0.0,1.0,38.9...| 0.0|[1.34210087888720...|[0.79283521791024...| 0.0|
| red| bad| 2|14.386294994851129|[1.0,0.0,2.0,14.3...| 0.0|[1.80828707711963...|[0.85915472458301...| 0.0|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
11 - evaluator
AUC 0.9210526315789473
12 - objectiveHistory
0.6930670630541909,0.5961979573995868,0.5268590504745408,0.47249879942722717,0.4635853436671372,0.4548517016401766,0.4501786908817158,0.44601002558944336,0.4436249409133597,0.4416524078673581,0.4415889464730704,0.44157753890376344,0.44157750876351043,0.441577487735776,0.44157748049670753,0.4415774778777291,0.44157747771911865,0.44157747771846245
13 - model save
12.0 K /user/gaoToby/model_saved/tmp/modelLocation/bestModel
1.1 K /user/gaoToby/model_saved/tmp/modelLocation/estimator
376 /user/gaoToby/model_saved/tmp/modelLocation/evaluator
3.7 K /user/gaoToby/model_saved/tmp/modelLocation/metadata
|