IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 数据结构与算法 -> 线性回归--手动 -> 正文阅读

[数据结构与算法]线性回归--手动

解析解求解主要需要推导出 W计算公式

y = w ? x + b = W ? X y = w * x + b = W*X y=w?x+b=W?X
为例,选取均方误差为损失函数:
l o s s = 1 2 n ? ( y ? y p r e d ) 2 loss = \frac{1}{2n} * (y - y_{pred})^2 loss=2n1??(y?ypred?)2
直接贴出推导结果(我推的太不好了):
W = ( X @ X T ) ? 1 @ X @ Y W = (X@ X^T) ^{-1}@ X @ Y W=(X@XT)?1@X@Y

代码:

import numpy as np
import matplotlib.pyplot as plt
def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y

np.random.seed(10)
x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))
w = np.linalg.pinv(x @ np.transpose(x)) @ x @ y
print(w)
y_pred = w @ x
plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)
plt.show()

结果:
[3.1382164 0.78223531]
在这里插入图片描述

梯度下降求解以
y = w ? x + b = W ? X y = w * x + b = W*X y=w?x+b=W?X
为例,选取均方误差为损失函数:
l o s s = 1 2 n ? ( y ? y p r e d ) 2 loss = \frac{1}{2n} * (y - y_{pred})^2 loss=2n1??(y?ypred?)2
梯度计算:
? = 1 n ? ( y ? W ? X ) ? X T \nabla = \frac{1}{n} * (y - W*X) *X^T ?=n1??(y?W?X)?XT
利用梯度更新参数,注意梯度方向,系数更新公式:
W = W + a ? ? W = W + a * \nabla W=W+a??
a为学习率,不要太大,不然结果会乱跳(不收敛)

代码:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def make_fake_data():
    # y = 3*x + 1
    x = np.random.rand(20) * 10
    y = 3 * x + (1 + np.random.randn(20)*3)
    return x, y
def monitor_mse(y, y_pred):
    Loss = ((y - y_pred) @ np.transpose(y - y_pred)) / len(y)
    return Loss
np.random.seed(10)


x, y = make_fake_data()
x_b = np.ones(20)
x = np.vstack((x, x_b))

k = 2001
a = 0.01  # 学习率小点好,大了会乱跑
A = np.random.rand(2)

for i in range(1, k):
    y_pred = np.transpose(A) @ x
    A = A + a * ((y - y_pred) / len(y)) @ np.transpose(x)


    if i % 500 == 0:
        print(f"第 {i} 次 A:", A)
        print(f"第 {i} 次 A:", monitor_mse(y, y_pred))



plt.scatter(x[0, :], y)
plt.plot(x[0, :], y_pred)

plt.show()

结果:
第 500 次 A: [3.11735897 0.92539329]
第 500 次 A: 11.390507360402756
第 1000 次 A: [3.13215241 0.82385636]
第 1000 次 A: 11.385755910470582
第 1500 次 A: [3.13645339 0.79433601]
第 1500 次 A: 11.385354285166539
第 2000 次 A: [3.13770383 0.7857534 ]
第 2000 次 A: 11.385320337027098
在这里插入图片描述

  数据结构与算法 最新文章
【力扣106】 从中序与后续遍历序列构造二叉
leetcode 322 零钱兑换
哈希的应用:海量数据处理
动态规划|最短Hamilton路径
华为机试_HJ41 称砝码【中等】【menset】【
【C与数据结构】——寒假提高每日练习Day1
基础算法——堆排序
2023王道数据结构线性表--单链表课后习题部
LeetCode 之 反转链表的一部分
【题解】lintcode必刷50题<有效的括号序列
上一篇文章      下一篇文章      查看所有文章
加:2022-03-16 22:43:30  更:2022-03-16 22:51:34 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 1:59:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码