IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 游戏开发 -> pytorch earlystopping使用 -> 正文阅读

[游戏开发]pytorch earlystopping使用

作者:token comment

earlystopping实现
https://github.com/Bjarten/early-stopping-pytorch

自己实现LinearRegression

# 我通过一个线性回归(或者一个自编码器)来联系使用earlystop
import torch
import numpy as np
import torch.nn as nn

#X1=np.linspace(-10,10,100)
#X2=np.linspace(-20,20,100)
#X3=np.linspace(0,40,100)
# 这样的样本不是独立的,不要这样构造数据,否则求解的系数会不对的
n_sample=300
X=np.random.randn(n_sample,3)
noise=np.random.randn(n_sample)

Y=X[:,0]*0.5+X[:,1]*(-0.7)+X[:,2]*4-0.3+0.2*noise
from sklearn.linear_model import LinearRegression
model=LinearRegression()
model.fit(X,Y)
print(model.coef_)
print(model.intercept_)
# 首先验证数据的合理性

#train_X=np.c_[X,noise]
train_X=X.astype("float32").copy()
train_Y=Y.astype("float32").copy()
train_Y=train_Y.reshape((len(train_Y),1)) ## 这时候的维度必须是n*1的二维矩阵,不能是n的一维矩阵
#print(train_X.shape)
from torch.utils.data import TensorDataset
trainset=TensorDataset(torch.from_numpy(train_X),torch.from_numpy(train_Y))
dataloader=torch.utils.data.DataLoader(trainset,batch_size=20,num_workers=0,shuffle=True)

class LR_torch(nn.Module):
    def __init__(self,input_shape=3,output_shape=1):
        super(LR_torch,self).__init__()# 先要对父类初始化
        print("初始化")
        self.input_shape=input_shape
        self.output_shape=output_shape
        self.linear=nn.Linear(self.input_shape,self.output_shape) # 默认的bias就是True
        torch.nn.init.xavier_uniform_(self.linear.weight)
    def forward(self,x):
        return self.linear(x)
        #self.act=,不需要激活函数
model=LR_torch()
optimizer=torch.optim.SGD(params=model.parameters(),lr=0.01)
criterion=nn.MSELoss()
from torch.autograd import Variable
num_epochs=100
import matplotlib.pyplot as plt

train_loss=[]
for i in range(1,num_epochs+1):
    temp_loss=[]
    for ids,(input,target) in enumerate(dataloader):
        input=Variable(input)
        target=Variable(target)
        y_pred=model(input)
        loss=criterion(y_pred,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        temp_loss.append(loss.item())
        #print("loss={}".format(loss.item()))
    train_loss.append(np.mean(temp_loss))

plt.plot(range(len(train_loss)),train_loss)
## 打印权重
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
        
print(model.linear.weight)
print(model.linear.bias)

# for name in model.state_dict():
#     print(name)

结果如下
在这里插入图片描述
可以看到,优化的结果没有什么问题

那么我现在的想法很简单,用earlystopping 让模型在100个epoch之内停止下来

加入earlystopping

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr  2 06:17:43 2022

@author: xiaokangyu
"""

# 我通过一个线性回归(或者一个自编码器)来学习使用earlystop
import torch
import numpy as np
import torch.nn as nn
import sys
sys.path.append("/Users/xiaokangyu/Desktop/DESC_INSCT/学习/early-stopping-pytorch-master/")
from pytorchtools import EarlyStopping ## 导入earlytopping

#X1=np.linspace(-10,10,100)
#X2=np.linspace(-20,20,100)
#X3=np.linspace(0,40,100)
# 这样的样本不是独立的
n_sample=300
X=np.random.randn(n_sample,3)
noise=np.random.randn(n_sample)

Y=X[:,0]*0.5+X[:,1]*(-0.7)+X[:,2]*4-0.3+0.2*noise
from sklearn.linear_model import LinearRegression
model=LinearRegression()
model.fit(X,Y)
print(model.coef_)
print(model.intercept_)


#train_X=np.c_[X,noise]
train_X=X.astype("float32").copy()
train_Y=Y.astype("float32").copy()
train_Y=train_Y.reshape((len(train_Y),1))
#print(train_X.shape)
from torch.utils.data import TensorDataset
trainset=TensorDataset(torch.from_numpy(train_X),torch.from_numpy(train_Y))
dataloader=torch.utils.data.DataLoader(trainset,batch_size=20,num_workers=0,shuffle=True)

class LR_torch(nn.Module):
    def __init__(self,input_shape=3,output_shape=1):
        super(LR_torch,self).__init__()# 先要对父类初始化
        print("初始化")
        self.input_shape=input_shape
        self.output_shape=output_shape
        self.linear=nn.Linear(self.input_shape,self.output_shape) # 默认的bias就是True
        torch.nn.init.xavier_uniform_(self.linear.weight)
    def forward(self,x):
        return self.linear(x)
        #self.act=,不需要激活函数
model=LR_torch()
optimizer=torch.optim.SGD(params=model.parameters(),lr=0.01)
criterion=nn.MSELoss()
from torch.autograd import Variable
num_epochs=100 # 这个地方有30多个epoch就停了
import matplotlib.pyplot as plt

train_loss=[]

# initialize the early_stopping object
# 定义这个earlystop 
patience=3
early_stopping = EarlyStopping(patience=patience, delta=1e-4,verbose=True) # 定义early_stopping
for i in range(1,num_epochs+1):
    temp_loss=[]
    for ids,(input,target) in enumerate(dataloader):
        input=Variable(input)
        target=Variable(target)
        y_pred=model(input)
        loss=criterion(y_pred,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        temp_loss.append(loss.item())
        #print("loss={}".format(loss.item()))
    train_loss.append(np.mean(temp_loss))
    print("epoch={},train_loss={}".format(i,train_loss[-1]))
    early_stopping(train_loss[-1], model)# 每次传入当前epoch的loss值,不是从epoch=0开始的全部的loss 列表
        
    if early_stopping.early_stop:
        print("Early stopping")
        break

    # load the last checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))# 如果提前退出,那么需要加载之前保存的模型参数

plt.plot(range(len(train_loss)),train_loss)
## 打印权重
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name)
        
print(model.linear.weight)
print(model.linear.bias)

# for name in model.state_dict():
#     print(name)

最终结果如下
在这里插入图片描述
在这里插入图片描述
可以开电脑,模型在20多个epoch时候集直接停止了,达到需求

  游戏开发 最新文章
6、英飞凌-AURIX-TC3XX: PWM实验之使用 GT
泛型自动装箱
CubeMax添加Rtthread操作系统 组件STM32F10
python多线程编程:如何优雅地关闭线程
数据类型隐式转换导致的阻塞
WebAPi实现多文件上传,并附带参数
from origin ‘null‘ has been blocked by
UE4 蓝图调用C++函数(附带项目工程)
Unity学习笔记(一)结构体的简单理解与应用
【Memory As a Programming Concept in C a
上一篇文章      下一篇文章      查看所有文章
加:2022-04-04 12:42:53  更:2022-04-04 12:44:27 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/17 15:21:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码