IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> mini-batch学习 -> 正文阅读

[Python知识库]mini-batch学习

mini-batch学习

交叉熵误差
在这里插入图片描述

'''
在训练数据过多的情况下,全部进行学习这是不现实的,会耗费大量时间和算力。因此我们从中选取一部分数据进行训练,这种方法被称为
mini-batch.
'''
#读取MNIST数据集代码(数据集代码位于dataset.mnist中)
import sys,os
sys.path.append('D:/pycharm/SAVE_FILE/pythonProject/【源代码】深度学习入门:基于Python的理论与实现')#添加自己的引用模块搜索目录。
"""
对于需要引用的模块和需要执行的脚本文件不在同一个目录时,可以按照如下形式来添加路径:
①导入的XX包在另一个项目文件中,在自己写的程序中需要用到XX包。
②所以我们在运行自己写的程序时,首先加载导入的XX包,加载的时候python解释器会去sys.path默认搜索路径去搜索。
③如果通过sys.path中的路径可以搜索到XX包,然后加载。
④如果无法通过sys.path中的路径搜索到XX包,即说明自己的程序中引用的XX包,与自己程序脚本所在目录不在同一个路径。(无法在自己的程序脚本中根据默认搜索路径查找到XX包)
⑤然后我们就需要将XX包的搜索路径添加到自己程序脚本的默认搜索路径中,重新运行自己的程序脚本,先搜索XX包在加载XX包。
"""
import numpy as np
from dataset.mnist import load_mnist               #导入dataset文件夹下的load_mnist文件
(x_train,t_train),(x_test,t_test)=load_mnist(normalize=True,one_hot_label=True)   #normalize=True图像像素值正规化到0-1, one_hot_label为True的情况下,标签作为one-hot数组返回
#one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
print(x_train.shape)
print(t_train.shape)

#读取上面MINST数据后,训练数据有60000个,输入数据时784维(28x28)的图像数据,监督数据是10维的数据。因此上面的x_train,t_train的形状为(60000,784)和(60000,10)
#现在取出10笔数据
train_size=x_train.shape[0]                     #第一维度的大小(shape[0])60000
batch_size=10
batch_mask=np.random.choice(train_size,batch_size)   #np.random.choice()可以从指定的范围内选取想要的数据,比如这里是在[0,600000)内选取10个数据
x_batch=x_train[batch_mask]                           #重新赋值,取出10笔数据,之后只需要指定这些随机选出的索引,取出mini-batch,计算损失函数
t_batch=t_train[batch_mask]

'''
在先前的交叉熵误差中只计算了单个数据的损失函数,现在将所有mini-batch的训练数据的损失函数总数计算出来。
只要改良一下之前实现的对应单个数据的交叉熵误差就可以了。这里,我们来实现一个可以同时处理单个数据和批量数据(数据作为batch集中输人)两种情况的函数。

'''

def cross_entropy_error(y,t):
    if y.ndim==1:                                  #若输出为一维数据,如果不进行判断batch_size则为10
        y=y.reshape(1,y.size)                     #在不改变矩阵元素的情况下进行重新排列,重塑矩阵形状,这里是1xy.size
        t=t.reshape(1,t.size)
#这里,y是神经网络的输出,t是监督数据。y的维度为1时,即求单个数据的交叉熵误差时,需要改变数据的形状。
#并且,当输人为mini-batch时,要用batch的个数进行正规化,计算单个数据的平均交叉熵误差。
    batch_size=y.shape[0]                          #shape[0]即最外层(第一维度,也就是中括号由外到里,为1到N维)的个数,即batch_size(每一批的数据量)
    print("y.shape[0]:", y.shape[0])
    print("y.shape:", y.shape)
    return -np.sum(t*np.log(y+1e-7))/batch_size    #交叉熵误差

#测试一下交叉熵函数
t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
y = np.array([0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0])
test_result = cross_entropy_error(y,t)
print(test_result)

#上述方法,监督数据为one-hot形式,下面介绍标签形式(非one-hot)
def cross_entropy_error1(y,t):
    if y.ndim==1:
        t=t.reshape(1,t.size)
        y=y.reshape(1,y.size)
    batch_size=y.shape[0]
    print(y[np.arange(batch_size),t])
    return -np.sum(np.log(y[np.arange(batch_size),t]+1e-7))/batch_size
    #y[np.arange(batch_size),t],生成新的数组,例如batch_size为3,t为[1,2,3],那么结果为y[0,1],y[1,2],y[2,3]


运行结果为:在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-04-26 11:38:17  更:2022-04-26 11:39:21 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 16:42:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码