IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch深度学习实践_p4_Backpropagation -> 正文阅读

[人工智能]pytorch深度学习实践_p4_Backpropagation

对于Backpropagation的一些理解

做gradient descent 的关键步骤是求loss对w(特征值,这里就用weight做示例)求导,当没有用到神经网络的时候对于其还是比较好求的,但是当嵌套了多层神经元后,其求导难度就大大增加了。这时backpropagation的作用就来了通过一次forward和一次backward以后就能求出。过程如下:
在这里插入图片描述

总结 :backpropagation的作用就是在神经网络中求出最后一层一层对于第一层input的导数,以便进行梯度下降

torch

  • w.required_grad()

作用告诉Tensor需要计算梯度,默认不需要

  • item()与data()的区别
  • item()返回的是一个具体的数值。
    注意:对于元素不止一个的tensor列表,使用item()会报错
  • .data返回原数据的深度拷贝是一个无梯度的tensor,=,即不会生成计算图

作业代码

这里的训练感觉有点问题,我预想的函数是y = x^2 + 2*x + 1,w2的初始值如果不是给的2的话预测偏差会特别大,但是w2的初始值不应该只是会影响梯度下降的速度吗?最终应该都会下降到2左右呀,不应该导致最终预测结果出现特别大的误差

import numpy as np
import matplotlib.pyplot as plt
import torch
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [4.0, 9.0, 16.0, 25.0]

w1 = torch.Tensor([1.0])  #将w变成张量,可以认为是一个高维数组
w1.requires_grad = True   #需要计算梯度
w2 = torch.Tensor([2.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True
def forward(x):
    return w1*x**2 + w2*x +b

def loss(x,y):
    y_pre = forward(x)
    return (y_pre-y)**2

print("Predict (before training)",4,forward(4).item()) #这里因为w是张量,所以w*x后,返回的也是个张量,所以需要item()取值

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        print("\tgrad: ", x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
        w1.data = w1.data - 0.01*w1.grad.data
        w2.data = w2.data - 0.01 * w2.grad.data
        b.data = b.data - 0.01 * b.grad.data

        w1.grad.data.zero_()  #这里需要把这个计算图中求得数据清0 防止对下一个循环的值产生影响,当然这个不是所有的都要清零,有一些案例需要用到前一计算图的结果
        w2.grad.data.zero_()
        b.grad.data.zero_()
    print("progress: ", epoch, l.item())
print("Predict (before training)",4,forward(4).item())

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-14 14:01:52  更:2021-08-14 14:04:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 1:06:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码