IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于Pytorch对凸函数采用SGD算法优化实例(附源码) -> 正文阅读

[人工智能]基于Pytorch对凸函数采用SGD算法优化实例(附源码)

实例说明

基于Pytorch,手动编写SGD(随机梯度下降)方法,求-sin2(x)-sin2(y)的最小值,x∈[-2.5 , 2.5] , y∈[-2.5 , 2.5]。

画一下要拟合的函数图像

代码

import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错

def fix_fun(x,y):  #构建函数fix_fun = (x^2+y+10)^4+(x+y^2-5)^4
    return -numpy.sin(x)**2 - numpy.sin(y)**2

#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)
ax.plot_surface(x,y,z, cmap='rainbow')
matplotlib.pyplot.show()

画出来后,曲面长下面的样子↓
在这里插入图片描述

SGD算法构建思路

非常简单:用.backward()方法求出要优化的函数的梯度,按照SGD的定义找出最优解。
即:
θ=θ-η·▽J(θ)

运行结果

看看随机生成的几个优化的结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

源码

老规矩,细节的备注仍然写在代码里面。

import torch
from torch.autograd import Variable
import matplotlib.pyplot
from mpl_toolkits.mplot3d import Axes3D
import numpy
import os  #这句要加,要不画图会报错
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' #这句要加,要不画图会报错

def fix_fun_numpy(x,y):  #构建函数fix_fun(因为tensor数据类型和numpy画图不能混用,构建两个函数)
    return -numpy.sin(x)**2 - numpy.sin(y)**2
def fix_fun_tensor(x,y):  #构建函数fix_fun
    return -torch.sin(x)**2 - torch.sin(y)**2

#画一下函数图像
x = numpy.arange(-2.5,2.5,0.1)
y = numpy.arange(-2.5,2.5,0.1)
x, y = numpy.meshgrid(x, y)
z = fix_fun_numpy(x,y)
fig = matplotlib.pyplot.figure()
ax = Axes3D(fig)



#构建SGD算法
X = (torch.rand(2,dtype=float)-0.5)*5 #0~1分布改成-2.5~2.5分布,生成随机初始点(所以叫‘随机’梯度下降)
X_dots = []
Y_dots = []
Z_dots = []

for iter in range(50):
    X = Variable(X, requires_grad = True)  #需要求导的时候要加这一句
    Z = fix_fun_tensor(X[0],X[1])  #这一句一定要写在Variable后面
    Z.backward()  #求导(梯度)
    X = X - 0.1*X.grad #learning rate=0.1
    X_dots.append(X[0].detach().numpy())  #记录下每个学习过程点(为了后面画出学习的路径)
    Y_dots.append(X[1].detach().numpy())
    Z_dots.append(Z.detach().numpy())

ax.plot_wireframe(x, y, z)  #换成wireframe,因为用surface会和曲线重合,看不出来
ax.plot(X_dots,Y_dots,Z_dots,color='red')

matplotlib.pyplot.show()

后记

喜欢探索的同学还可以把上面的源码改一改,并试试Adagrad, Adam, Momentum这些方法,也对比下各自的优劣。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-24 18:28:52  更:2021-12-24 18:29:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 21:41:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码