IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用scala做机器学习模型主要步骤示例 -> 正文阅读

[人工智能]使用scala做机器学习模型主要步骤示例

本文介绍使用scala做机器学习模型的一个主要步骤示例。这里主要列了些基本环节,可以在此基础上进行扩充。

object mlExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TobyGao")
      .enableHiveSupport()
      .getOrCreate()
    val modelPath = "/user/Tobygao/model_saved"
    val dataPath = "/user/Tobygao/ml_data"

    //1- load data
    var df = spark.read.json(dataPath+"/data/simple-ml")

    //2- train/test Split 
    val Array(train, test) = df.randomSplit(Array(0.7, 0.3))
   
   
    //3 featureVector -- VectorAssember or RFormula
    val rForm = new RFormula()
   
     //4 define model
    import org.apache.spark.ml.classification.LogisticRegression
    val lr = new LogisticRegression()
            .setLabelCol("label")
            .setFeaturesCol("features")
    println(lr.explainParams())

    //5- pipeline
    import org.apache.spark.ml.Pipeline
    val stages = Array(rForm, lr)
    val pipeline = new Pipeline().setStages(stages)

    //6 - ParamGridBuilder 参数构造器
    import org.apache.spark.ml.tuning.ParamGridBuilder
    val params = new ParamGridBuilder()
      .addGrid(rForm.formula, Array(
        "lab ~ . + color:value1",
        "lab ~ . + color:value1 + color:value2"))
      .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
      .addGrid(lr.regParam, Array(0.1, 2.0))
      .build()

    //7 - Evaluator
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderROC") //AUC
      .setRawPredictionCol("prediction")
      .setLabelCol("label")

    //8 - TrainValidationSplit
    import org.apache.spark.ml.tuning.TrainValidationSplit
    val tvs = new TrainValidationSplit()
      .setTrainRatio(0.75) // 训练集、验证集拆分比例
      .setEstimatorParamMaps(params) //参数网格构造器
      .setEstimator(pipeline) //估计器
      .setEvaluator(evaluator)  //评价器

    //9 - model fit
    val tvsFitted = tvs.fit(train)

    //10 - model predict
    val tvsPredict = tvsFitted.transform(test)
    tvsPredict.show()

    //11- show evaluate
    tvsPredict.show()
    println(evaluator.evaluate(tvsPredict)) // AUC

    //12- best model
    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.ml.classification.LogisticRegressionModel
    val trainedPipeline = tvsFitted.bestModel.asInstanceOf[PipelineModel]
    val TrainedLR = trainedPipeline.stages(1).asInstanceOf[LogisticRegressionModel]
    val summaryLR = TrainedLR.summary
    println(summaryLR.objectiveHistory.mkString(",")) // 查看模型收敛速度,这是个Array,存放了每次迭代训练后的目标函数objective的值。可以通过查看这个数据来判断是否应该增大训练的迭代次数、早停或者调参优化模型

    //13 - model save
    tvsFitted.write.overwrite().save(modelPath+"/tmp/modelLocation")

    //14 - model load
    import org.apache.spark.ml.tuning.TrainValidationSplitModel
    val model = TrainValidationSplitModel.load(modelPath+"/tmp/modelLocation")
    model.transform(test)

  }
}

结果:

1-load data
+-----+----+------+------------------+
|color| lab|value1|            value2|
+-----+----+------+------------------+
|green|good|     1|14.386294994851129|
|green| bad|    16|14.386294994851129|
| blue| bad|     8|14.386294994851129|
| blue| bad|     8|14.386294994851129|
| blue| bad|    12|14.386294994851129|
|green| bad|    16|14.386294994851129|
|green|good|    12|14.386294994851129|
|  red|good|    35|14.386294994851129|
|  red|good|    35|14.386294994851129|
|  red| bad|     2|14.386294994851129|
|  red| bad|    16|14.386294994851129|
|  red| bad|    16|14.386294994851129|
| blue| bad|     8|14.386294994851129|
|green|good|     1|14.386294994851129|
|green|good|    12|14.386294994851129|
| blue| bad|     8|14.386294994851129|
|  red|good|    35|14.386294994851129|
| blue| bad|    12|14.386294994851129|
|  red| bad|    16|14.386294994851129|
|green|good|    12|14.386294994851129|
+-----+----+------+------------------+ 

3-RFormula
+-----+----+------+------------------+--------------------+-----+
|color| lab|value1|            value2|            features|label|
+-----+----+------+------------------+--------------------+-----+
|green|good|     1|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
| blue| bad|     8|14.386294994851129|(10,[2,3,6,9],[8....|  0.0|
| blue| bad|    12|14.386294994851129|(10,[2,3,6,9],[12...|  0.0|
|green|good|    15| 38.97187133755819|(10,[1,2,3,5,8],[...|  1.0|
|green|good|    12|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
|green| bad|    16|14.386294994851129|(10,[1,2,3,5,8],[...|  0.0|
|  red|good|    35|14.386294994851129|(10,[0,2,3,4,7],[...|  1.0|
|  red| bad|     1| 38.97187133755819|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|     2|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|    16|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
|  red|good|    45| 38.97187133755819|(10,[0,2,3,4,7],[...|  1.0|
|green|good|     1|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
| blue| bad|     8|14.386294994851129|(10,[2,3,6,9],[8....|  0.0|
| blue| bad|    12|14.386294994851129|(10,[2,3,6,9],[12...|  0.0|
|green|good|    15| 38.97187133755819|(10,[1,2,3,5,8],[...|  1.0|
|green|good|    12|14.386294994851129|(10,[1,2,3,5,8],[...|  1.0|
|green| bad|    16|14.386294994851129|(10,[1,2,3,5,8],[...|  0.0|
|  red|good|    35|14.386294994851129|(10,[0,2,3,4,7],[...|  1.0|
|  red| bad|     1| 38.97187133755819|(10,[0,2,3,4,7],[...|  0.0|
|  red| bad|     2|14.386294994851129|(10,[0,2,3,4,7],[...|  0.0|
+-----+----+------+------------------+--------------------+-----+ 

10- model prediction
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
|color| lab|value1|            value2|            features|label|       rawPrediction|         probability|prediction|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|     8|14.386294994851129|(7,[2,3,6],[8.0,1...|  0.0|[1.81841935188104...|[0.86037635368405...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
| blue| bad|    12|14.386294994851129|(7,[2,3,6],[12.0,...|  0.0|[2.15923553226233...|[0.89652865416576...|       0.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green| bad|    16|14.386294994851129|[0.0,1.0,16.0,14....|  0.0|[-0.6607070390540...|[0.34058080292169...|       1.0|
|green|good|     1|14.386294994851129|[0.0,1.0,1.0,14.3...|  1.0|[-0.4860199728364...|[0.38083160751668...|       1.0|
|green|good|     1|14.386294994851129|[0.0,1.0,1.0,14.3...|  1.0|[-0.4860199728364...|[0.38083160751668...|       1.0|
|green|good|    12|14.386294994851129|[0.0,1.0,12.0,14....|  1.0|[-0.6141238213959...|[0.35111907339348...|       1.0|
|green|good|    12|14.386294994851129|[0.0,1.0,12.0,14....|  1.0|[-0.6141238213959...|[0.35111907339348...|       1.0|
|green|good|    15| 38.97187133755819|[0.0,1.0,15.0,38....|  1.0|[-1.1954765118448...|[0.23228089715736...|       1.0|
|green|good|    15| 38.97187133755819|[0.0,1.0,15.0,38....|  1.0|[-1.1954765118448...|[0.23228089715736...|       1.0|
|  red| bad|     1| 38.97187133755819|[1.0,0.0,1.0,38.9...|  0.0|[1.34210087888720...|[0.79283521791024...|       0.0|
|  red| bad|     1| 38.97187133755819|[1.0,0.0,1.0,38.9...|  0.0|[1.34210087888720...|[0.79283521791024...|       0.0|
|  red| bad|     2|14.386294994851129|[1.0,0.0,2.0,14.3...|  0.0|[1.80828707711963...|[0.85915472458301...|       0.0|
+-----+----+------+------------------+--------------------+-----+--------------------+--------------------+----------+ 

11 - evaluator
AUC 0.9210526315789473 

12 - objectiveHistory
0.6930670630541909,0.5961979573995868,0.5268590504745408,0.47249879942722717,0.4635853436671372,0.4548517016401766,0.4501786908817158,0.44601002558944336,0.4436249409133597,0.4416524078673581,0.4415889464730704,0.44157753890376344,0.44157750876351043,0.441577487735776,0.44157748049670753,0.4415774778777291,0.44157747771911865,0.44157747771846245

13 - model save
12.0 K  /user/gaoToby/model_saved/tmp/modelLocation/bestModel
1.1 K   /user/gaoToby/model_saved/tmp/modelLocation/estimator
376     /user/gaoToby/model_saved/tmp/modelLocation/evaluator
3.7 K   /user/gaoToby/model_saved/tmp/modelLocation/metadata
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-05 17:21:26  更:2021-08-05 17:21:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/7 14:49:00-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码