IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 梯度下降代码 -> 正文阅读

[人工智能]梯度下降代码

文章目录

参数

  • 学习率:alpha

  • 迭代次数:iteration

  • 梯度下降要乘以负号,若是梯度上升则为正号

公式

在这里插入图片描述

代码

from sklearn import datasets
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt1
import numpy as np
import random

# 加载数据
data = datasets.load_boston()
X, Y = data['data'][:,5],data['target']

# 参数设置
# 迭代次数
iteration = 20000
# 学习率
learningRate = 1e-3
# 随机指定k,b
k = random.randint(0, 50)
b = random.randint(-50, 50)
# 迭代范围内的最优参数k,b
best_k, best_b = None, None
# 'inf'表示正无穷大
min_loss = float('inf')

# k求偏导
def partial_k(y_true, y_predict, X):
    return -2*np.mean((np.array(y_true) - np.array(y_predict))*np.array(X))

# b求偏导
def partial_b(y_true, y_predict):
    return -2*np.mean(np.array(y_true) - np.array(y_predict))

# 求损失
def loss(y_true, y_predict):
    return np.mean((np.array(y_true) - np.array(y_predict))**2)
# 预测值
def predict(k, X, b):
    return k * X + b

plt.figure()
plt.scatter(X, Y, color='red', alpha=0.5)
# 随机指定的k,b绘制的图
plt.figure()
plt.scatter(X, Y, color='red', alpha=0.5)
plt.plot(X, predict(k, X, b), color='green', linewidth=3)
plt.show()

# 更新k和b
for i in range(iteration):
    # 预测值
    y_predict = predict(k, X, b)
    get_loss = loss(Y, y_predict)
    if(get_loss < min_loss):
        best_k = k
        best_b = b
        min_loss = get_loss
    k = k - partial_k(Y, y_predict, X) * learningRate
    b = b - partial_b(Y, y_predict) * learningRate

print("min_loss = ",min_loss)
# 梯度下降绘制的图
plt.figure()
plt.scatter(X, Y, color = "red")
plt.plot(X, predict(best_k, X, best_b), color = "blue")
plt.show()

在这里插入图片描述在这里插入图片描述

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-01-17 11:30:57  更:2022-01-17 11:33:33 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 17:01:05-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码