IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用mllib完成mnist手写识别任务 -> 正文阅读

[人工智能]使用mllib完成mnist手写识别任务

使用mllib完成mnist手写识别任务

  1. 小提示,通过restart命令重启已经退出了的容器

    sudo docker restart <contain id>

    请添加图片描述

  2. 完成识别任务准备工作

    1. 从以下网站下载数据集:

      MNIST手写数字数据库,Yann LeCun,Corinna Cortes和Chris Burges

      数据集包含以下四个压缩包,下载后解压得到数据集文件:

      • t10k-images-idx3-ubyte.gz
      • t10k-labels-idx1-ubyte.gz
      • train-images-idx3-ubyte.gz
      • train-labels-idx1-ubyte.gz
    2. 通过以下python程序,将数据集文件转换为csv文件

      def convert(imgf, labelf, outf, n):
          f = open(imgf, "rb")
          o = open(outf, "w")
          l = open(labelf, "rb")
      
          f.read(16)
          l.read(8)
          images = []
      
          for i in range(n):
              image = [ord(l.read(1))]
              for j in range(28 * 28):
                  image.append(ord(f.read(1)))
              images.append(image)
      
          for image in images:
              o.write(",".join(str(pix) for pix in image) + "\n")
          f.close()
          o.close()
          l.close()
      
      
      # 数据集在 http://yann.lecun.com/exdb/mnist/ 下载
      convert("train-images.idx3-ubyte", "train-labels.idx1-ubyte",
              "mnist_train.csv", 60000)
      convert("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte",
              "mnist_test.csv", 10000)
      

      通过这个程序将在根目录下产生以下两个文件:

      • mnist_train.csv
      • mnist_test.csv
    3. 通过以下python程序转换csv文件为libsvm文件

      import csv
      
      
      def execute(data, savepath):
      
          csv_reader = csv.reader(open(data))
          f = open(savepath, 'wb')
          for line in csv_reader:
              label = line[0]
              features = line[1:]
              libsvm_line = label + ' '
      
              for index, feature in enumerate(features):
                  libsvm_line += str(index + 1) + ':' + feature + ' '
              f.write(bytes(libsvm_line.strip() + '\n', 'UTF-8'))
      
          f.close()
      
      
      execute('mnist_train.csv', 'mnist_train.libsvm')
      execute('mnist_test.csv', 'mnist_test.libsvm')
      

      该程序将生成以下两个.libsvm文件:

      • mnist_test.libsvm
      • mnist_train.libsvm
    4. 通过共享目录传递数据集到spark-master容器内。

    5. 进入spark-master

      sudo docker exec -it spark-master /bin/bash

      请添加图片描述

    6. 打开spark-shell

      spark-shell位于/spark/bin目录下

      使用./spark-shell命令进入spark-shell。

      请添加图片描述

  3. 完成识别任务

    1. 读取训练集

      val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")
      

      请添加图片描述

    2. 读取测试集

      val test = 		spark.read.format("libsvm").load("/data/mnist_test.libsvm")
      

      请添加图片描述

    3. 定义网络结构。如果计算机性能不好可以降低隐藏层的参数。

      val layers = Array[Int](784, 784, 784, 10)
      

      请添加图片描述

    4. 导入多层感知机与多分类评价器。

      import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
      import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
      

      请添加图片描述

    5. 使用多层感知机初始化训练器。

      val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)
      

      请添加图片描述

    6. 训练模型

      var model = trainer.fit(train)
      

      请添加图片描述

      请添加图片描述

    7. 输入测试集进行识别

      val result = model.transform(test)
      

      请添加图片描述

    8. 获取测试结果中的预测结果与实际结果

      val predictionAndLabels = result.select("prediction", "label")
      

      请添加图片描述

    9. 初始化评价器

      val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
      

      请添加图片描述

    10. 计算识别精度

      println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}")
      

      请添加图片描述

    11. 在result上创建临时视图

      result.toDF.createOrReplaceTempView("deep_learning")
      

      请添加图片描述

    12. 使用Spark SQL的方式计算识别精度

      spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
      

      请添加图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-05 11:19:11  更:2022-05-05 11:23:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 16:02:03-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码