IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 如何写神经网络 -> 正文阅读

[人工智能]如何写神经网络

数据预处理

读取文件

搭建模型

__init__函数和forward函数必须要有,class函数自动执行这两个函数,__init__函数中要定义整个网络中的所有网络层,前馈函数里要根据整个网络把__init__函数中定义的网络层连起来。

class Net(nn.Module):
	
	def __init__(self):
		...
		
	def forward(self,x):
		...
	
	def ...

搭建网络

定义所有要用到的网络层
比喻:制作积木块

  • 要注意维度问题,目前我是根据报错调试程序,程序说它需要什么类型的数据我就给他什么类型的数据
	def __init__(self):
        #使用super()方法调用基类的构造器,即nn.Module.__init__(self)
        super(CNN,self).__init__()
        # 1 input image channel ,6 output channels,5x5 square convolution kernel
        self.conv1=nn.Conv1d(500,12,2)
        # 6 input channl,16 output channels,5x5 square convolution kernel
        self.conv2=nn.Conv1d(6,16,4)
        # an affine operation:y=Wx+b
        self.fc1=nn.Linear(32,16)
        self.fc2=nn.Linear(16,8)
        self.fc3=nn.Linear(8,1)
        self.Sigmoid = nn.Sigmoid ()

前馈函数

将__init__函数中定义的各个网络层输入输出衔接起来
比喻:把积木块搭起来,用线串起来

  • 网络的输入x,通常数据类型为张量,如果要用cuda并行,记得把x放入gpu
  • 全连接层可以和激活函数写在一起,代码简洁并且不会忘写激活函数
	def forward(self,x):
        # x是网络的输入,然后将x前向传播,最后得到输出
        x=torch.Tensor(x).to(device)
        x=x.unsqueeze(0) # 输入,根据需要判断是否需要多加维度,比如LSTM、CNN需要三维数据,二维数据进来要多加一个维度
        x=self.Sigmoid(self.conv1(x)) # 一个一维卷积层,一个sigmoid层
        x=F.max_pool2d(x,(2,2)) # 定义了2x2的池化层
        x=self.Sigmoid(self.conv2(x))
        x=x.view(-1,self.num_flat_features(x)) 将x拉伸成一维张量,进行
        x=self.Sigmoid(self.fc1(x))
        x=self.Sigmoid(self.fc2(x))
        x=self.fc3(x)
        return x[0]
        
    # 根据需要,自己定义的方法。
    # 计算张量中共有多少个有效数据,计算数量,便于拉伸成一维张量
	def num_flat_features(self,x): 
        size=x.size()[1:] # 例如x的形状是[1, 16, 2],取出的size形状是[16, 2]
        num_features=1
        for s in size:# s的值分别为16,2,依次相乘,则为数据总数量
            num_features*=s
        return num_features # 返回的类型是int,一个数字

查看模型参数

print(Net().parameters())
params=list(Net().parameters())
print(len(params))
for i in range(len(params)):
    print(params[i].size())
#     print(params[i])

创建模型类的对象,定义损失函数和优化器

model = Net().to(device)
loss_function = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#建立优化器实例
print(model)

测试模型性能

这个函数是测试用来测试x_test y_test 数据

def eval(model): # 返回的是这10个 测试数据的平均loss
    test_epoch_loss=[]
    with torch.no_grad():
        optimizer.zero_grad()
        for i in range(0,10):
            y_pre = model(x_test[i])
            y_tru=torch.Tensor([y_test[i]]).to(device)
            test_loss = loss_function(y_pre,y_tru)
            test_epoch_loss.append(test_loss.item())
    return mean(test_epoch_loss)

定义数组分别存放训练和测试的损失,对整个训练过程重复epochs次迭代,每一次迭代中,需要逐个将训练集中的个体放入模型进行训练,

epochs =100
sum_train_epoch_loss=[] # 存储每个epoch 下 训练train数据的loss
sum_test_epoch_loss=[]  # 存储每个epoch 下 测试 test数据的loss
for epoch in range(epochs):
    epoch_loss=[]
    for i in range(0,85):
        #清除网络先前的梯度值
        optimizer.zero_grad()
        y_pred = model(x_train[i])
        y_true=torch.Tensor([y_train[i]]).to(device)
        #训练过程中,正向传播生成网络的输出,计算输出和实际值之间的损失值
        print("y_pred:",y_pred,",y_true:",y_true)
        single_loss = loss_function(y_pred,y_true)
        epoch_loss.append(single_loss.item())
        single_loss.backward()#调用backward()自动生成梯度
        optimizer.step()#使用optimizer.step()执行优化器,把梯度传播回每个网络 

    train_epoch_loss=mean(epoch_loss)   
    test_epoch_loss=eval(model)#返回的是这10个 测试数据的平均loss
    sum_train_epoch_loss.append(train_epoch_loss)
    sum_test_epoch_loss.append(test_epoch_loss)
    print("epoch:"+str(epoch)+"  train_epoch_loss: "+str(train_epoch_loss)+"  test_epoch_loss: "+str(test_epoch_loss))
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-29 10:15:46  更:2021-09-29 10:18:10 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/22 3:36:07-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码