IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 【pytorch】简单的线性回归模型 -> 正文阅读

[Python知识库]【pytorch】简单的线性回归模型

前言

回归是一种能更加多个变量之间的关系进行建模的一种方法,其在机器学习中有着官方运用。线性回归是其中最最最最简单的一种,其假设自变量与因变量之间是线性关系。利用pytorch就可以简单地写出线性回归的代码。

线性回归

首先需要知道线性回归的基本假设:

  1. 自变量和因变量之间是线性关系,并且允许存在些噪声。
  2. 存在的噪声都是比较“正常”的,符号正态分布。
    因此可以用一个简单是式子来表示这个模型。
    y ^ = w T x + b \hat{y} = w^Tx+b y^?=wTx+b
    其中w,x均是列向量。
    对于这个简单的问题,可以用数学的方法,直接求出解析解,最常见的就说最小二乘法。
    但是,因为线性回归模型过于简单,才可以这样求出答案,对于其他的模型,是不可能的,因此在这里同样是使用随机梯度下降法。

数据集

这种简单的数据集,并不需要去哪里下载,直接自行生成即可。

def data_maker(w, b, n_size): # y=w*x+b,n个数据
    X = torch.normal(0 , 1 , (n_size , len(w))) # n*len(w)的参数
    y = torch.matmul(X , w) + b
    y = y + torch.normal(0, 0.01 , y.shape)
    return X , y.reshape((-1 , 1))

读取

不失一般性,一般都是读取一个batch的,为了方便,可以利用yield将其写成一个迭代器。

def data_iter(batch_size , x , y):
    n = len(x)
    index = list(range(n))
    random.shuffle(index)
    for i in range(0 , n , batch_size):
        batch_index = torch.tensor(index[i:min(i + batch_size,n)])
        yield x[batch_index], y[batch_index]

模型

为了更好的理解,此处使用自己定义的函数。
但是需要用到pytorch的自动求梯度。

def linreg(X,w,b):
    return torch.matmul(X , w) + b

def loss_function(Y , y):
    return (Y - y.reshape(Y.shape))**2/2

def SGD(params , learning_rate , batch_size):
    with torch.no_grad():
        for param in params:
            param -= learning_rate * param.grad / batch_size
            param.grad.zero_()

训练

这里和普通的模型差不多

#training
for epoch in range(num_epochs):
    for  X ,Y in data_iter(batch_size , x , y):
        l = loss(net(X , w,b), Y)
        l.sum().backward()
        SGD([w,b] , 0.001 , batch_size)
    with torch.no_grad():
        train_loss = loss(net(x , w , b), y)
        print(f'epoch{epoch+1} , loos {float(train_loss.mean())}')

完整代码

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import random


# In[2]:


def data_maker(w, b, n_size): # y=w*x+b,n个数据
    X = torch.normal(0 , 1 , (n_size , len(w))) # n*len(w)的参数
    y = torch.matmul(X , w) + b
    y = y + torch.normal(0, 0.01 , y.shape)
    return X , y.reshape((-1 , 1))


# In[3]:


W = torch.tensor([4.0])
B = 1
x , y = data_maker(W, B, 1000)


# In[4]:


def data_iter(batch_size , x , y):
    n = len(x)
    index = list(range(n))
    random.shuffle(index)
    for i in range(0 , n , batch_size):
        batch_index = torch.tensor(index[i:min(i + batch_size,n)])
        yield x[batch_index], y[batch_index]


# In[5]:


def linreg(X,w,b):
    return torch.matmul(X , w) + b

def loss_function(Y , y):
    return (Y - y.reshape(Y.shape))**2/2

def SGD(params , learning_rate , batch_size):
    with torch.no_grad():
        for param in params:
            param -= learning_rate * param.grad / batch_size
            param.grad.zero_()


# In[6]:


batch_size = 5
num_epochs = 50
net = linreg
loss = loss_function
w = torch.normal(0,1 , size=(1,1) , requires_grad= True)
b = torch.normal(0,1 , size=(1,1) , requires_grad= True)


# In[7]:


#training
for epoch in range(num_epochs):
    for  X ,Y in data_iter(batch_size , x , y):
        l = loss(net(X , w,b), Y)
        l.sum().backward()
        SGD([w,b] , 0.001 , batch_size)
    with torch.no_grad():
        train_loss = loss(net(x , w , b), y)
        print(f'epoch{epoch+1} , loos {float(train_loss.mean())}')


# In[8]:


print(w , b)


  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-07-17 16:19:02  更:2022-07-17 16:22:12 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/27 3:19:19-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计