IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch 笔记:手动实现AR (auto regressive) -> 正文阅读

[人工智能]pytorch 笔记:手动实现AR (auto regressive)

1 导入库& 数据说明

import  numpy as np
import torch
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

x=np.array([[  1,   2],
       [  3,   4],
       [  5,   6],
       [  7,  10],
       [ 13,  18],
       [ 23,  30],
       [ 37,  50],
       [ 63,  86],
       [109, 146],
       [183, 246]])
xx=torch.tensor(x,dtype=torch.float32,requires_grad=True)
xx

我们有这样一组数据x,每一行代表一个时刻的二维数据,我们希望的是X[t]=2 \times X[t-3] + X[t-1]

2 模型

pytorch笔记 pytorch模型中的parameter与buffer_UQI-LIUWJ的博客-CSDN博客

import torch

class Net(torch.nn.Module):
    def __init__(self,lag_list):
        super(Net,self).__init__()
        
        self.lag_list=lag_list
        #时滞集

        self.l=len(self.lag_list)
        #时滞集大小

        self.param=torch.nn.Parameter(torch.zeros(self.l,2))
        #自回归系数,一个time lag 一行自回归系数
        #自回归系数放入Parameter里面,那么也会随着神经网路梯度下降更新
        
    #定义前向传播方式
    def forward(self,x):
        max_len=max(self.lag_list)
        #时滞集的最大元素,从这个元素之后的一个元素开始,考虑自回归预测和本身的差距

        l=x.shape[0]
        if(type(l)==type(1)):
            pass
        else:
            l=int(l)
        #这里为什么要这么分类一下呢?其实这也是我的一个疑问。
        '''
        直接用pytorch的话,l是一个int类型的数字
        但是如果使用了tensorboardX之后,l是一个Tensor,那么就得用类型转换回int
        '''

        tensor_len=l-max_len
        #表示需要进行(自回归预测-数据本身)比较的长度
        #也就是max_len后一位开始到最后的这一部分的长度
        
        lst_lag=[]
        #不同时滞集对应的输入数据的部分
        for i in self.lag_list:
            lst_lag.append(x[max_len-i:max_len-i+tensor_len].clone())
        #这里为什么要通过这种clone+放入列表的方式呢?
        #因为如果你直接在原始数据上进行切片操作【比如x[i]=x[i-1]+x[i-2]】这种
        #它会因为切片和inplace的问题,而无法传递梯度,最终会报错

        ret_tmp_origin=x[max_len:max_len+tensor_len].clone()
        #原始数据需要比较的一部分
        
        ret_var=self.param[0]*lst_lag[0]
        for i in range(1,self.l):
            ret_var=ret_var+self.param[i]*lst_lag[i]
        #自回归预测的部分

        return(ret_var-ret_tmp_origin)
        #自回归预测的部分和原始部分的差距

3 模型训练& tensorboardX可视化

net=Net([3,1])

from tensorboardX import SummaryWriter
 
writer = SummaryWriter('runs/AR_pytorch')
#提供一个路径,将使用该路径来保存日志
 
writer.add_graph(net,xx)
#添加net的计算图


 
#使用Adams进行梯度下降,设置学习率为0.02
optimizer=torch.optim.Adam(net.parameters(),lr=0.02)
 

#开始训练
for epoch in range(1000):
    prediction=net(xx)
    loss=torch.linalg.norm(prediction)
    #每次损失函数都是自回归预测和原始数据之间的L2范数
    
    writer.add_scalar('loss_',loss,epoch)
    #添加loss的summary信息
    
    optimizer.zero_grad()
    #清空上一步残余的参数更新值
    
    loss.backward()
    #误差反向传播,计算参数更新值
    
    optimizer.step()
    #将参数更新值施加到net的parameters上
    
    for i in net.parameters():
        writer.add_scalar('param-0_time_lag-3',i[0][0].item(),epoch)
        writer.add_scalar('param-1_time_lag-3',i[0][1].item(),epoch)
        writer.add_scalar('param-0_time_lag-1',i[1][0].item(),epoch)
        writer.add_scalar('param-1_time_lag-1',i[1][1].item(),epoch)
    
    

在tensorboardX上,有:

3.1 损失函数图像

?3.2 自回归系数图像

可以看到,最后和我们目标的[2,1]很接近了?

?

?

?3.3? TensorboardX 计算图

整体如下,左边是自回归的计算路径,右边是原始数据(截取max_len之后的部分)的计算路径

3.3.1 右侧计算路径

ret_tmp_origin=x[max_len:max_len+tensor_len].clone()
#原始数据需要比较的一部分

?

?3.3.2 左侧计算路径

左右路(下)? x[2:9] (X[t-1])

左左路(下)?? x[0:7] (X[t-3])

?左中路

?就是不同时滞对应的切片*对应的param值

最后这是一个减法,自回归得到的和原始数据相减

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-28 11:16:07  更:2021-11-28 11:16:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 2:51:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码