IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 6- 构建一个简单的分类网络 -> 正文阅读

[人工智能]6- 构建一个简单的分类网络

1. 说明

神经网络由对数据进行操作的层/模块组成。pytorch.nn的namespace命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch中的每个模块都是nn.Module的子类。神经网络本身就是由其他模块(层)组成的模块。这种嵌套结构允许轻松地构建和管理复杂的体系结构

2. GPU设置

如果我们想将神经网络和参数放到GPU上进行训练,那么我们就需要去设置device

device = "cuda" if torch.cuda.is_available() else "cpu"

3. 自定义神经网络

如果我们要自定义一个神经网络类,那么我们必须满足至少三个条件

  • 自定的类需要继承自 nn.Module
  • 自定义类有初始化函数 __init__
  • 自定义类有前向传播函数forward
# 1.自定义类的父类为nn.Module
class My_Model(nn.Module):
	# 初始化函数__init__
	def __init__(self):
		super(My_Model, self).__init__()
		self.flatten = nn.Flatten()
		self.linear_relu_stack = nn.Sequential(
			nn.Linear(28 * 28, 512),
			nn.ReLU(),
			nn.Linear(512, 512),
			nn.ReLU(),
			nn.Linear(512, 10))
# 前向传播函数forward
	def forward(self, X):
		X = self.flatten(X)
		logist = self.linear_relu_stack(X)
		return logist
  • 如果我们需要在GPU上训练函数,我们需要将模型和参数都同时放到GPU上,如果模型在GPU上,参数在CPU上,那么就会报错
# 判断设备上是否有cuda的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# new_model = My_Model().to(device=device)
# 实例化神经网络,并把神经网络放到GPU上
new_model = My_Model().to(device)

print(f"device={device}")

# 定义输入矩阵x,将参数x放入到GPU上
x = torch.rand(3, 28, 28, device=device)

4. nn.Flatten

作用:将输入的张量从第1维到最后一维度进行合并展开,第0维不变
比如:如果输入是 x = torch.randn(3,4,5) ,通过nn.Flatten 后变成了 y.shap=(3,4*5)=(3,20)

import torch
from torch import nn

# 实例化一个nn.Flatten对象
flatten = nn.Flatten()
# 初始化一个输入张量
x = torch.ones(3, 4, 5)
# 将x导入到flatten中后输出得到y
y = flatten(x)
print(f"y.shape={y.shape}")
# y.shape=torch.Size([3, 20])

5. nn.Linear

作用: 就是一个MLP,将输入的张量的特征维(就是最后一维)进行改变
y = x A T + b y=xA^T+b y=xAT+b

import torch
from torch import nn

# input_features = 20 ,output_features = 8
mylinear = nn.Linear(20,8)
# input 张量中最后一维大小为20
input = torch.randn(3,4,20)
# nn.Linear是改变最后一个维度那么可以看出
# (3,4,20) -> (3,4,8)
output = mylinear(input)
print(f"output.shape={output.shape}")
# output.shape=torch.Size([3, 4, 8])

6. nn.Sequential

作用:是一个顺序容器,可以将对张量的操作进行顺序操作

7. nn.ReLU

作用: 一个激活函数,对于张量中负数值改变成0,保留正数值

Before ReLU: tensor([[-0.2237, -0.2367,  0.2977, -0.3347, -0.4724,  0.3709,  0.0294, -0.0807,
         -0.5721, -0.1723, -0.8035,  0.4663, -0.0803, -0.2520,  0.8864,  0.4762,
          0.2638, -0.1566,  0.0790, -0.0876],
        [-0.2885, -0.3101,  0.2298, -0.4918, -0.3310,  0.4374,  0.1665,  0.1405,
         -0.5300, -0.3482, -0.4831, -0.0948,  0.1129, -0.3147,  0.8067,  0.3847,
          0.2725, -0.0671,  0.4173, -0.3192],
        [-0.2258, -0.1209,  0.6989, -0.4547, -0.3201, -0.1266, -0.1083, -0.0766,
         -0.2590, -0.3851, -0.7130,  0.4853,  0.2001, -0.3398,  0.9755,  0.3800,
         -0.0782,  0.2659,  0.2886, -0.5325]], grad_fn=<AddmmBackward0>)


After ReLU: tensor([[0.0000, 0.0000, 0.2977, 0.0000, 0.0000, 0.3709, 0.0294, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4663, 0.0000, 0.0000, 0.8864, 0.4762, 0.2638, 0.0000,
         0.0790, 0.0000],
        [0.0000, 0.0000, 0.2298, 0.0000, 0.0000, 0.4374, 0.1665, 0.1405, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1129, 0.0000, 0.8067, 0.3847, 0.2725, 0.0000,
         0.4173, 0.0000],
        [0.0000, 0.0000, 0.6989, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.4853, 0.2001, 0.0000, 0.9755, 0.3800, 0.0000, 0.2659,
         0.2886, 0.0000]], grad_fn=<ReluBackward0>)

8. 小结

构建一个简单的分类网络代码

import torch
from torch import nn
from torchsummary import summary


class My_Model(nn.Module):
	def __init__(self):
		super(My_Model, self).__init__()
		self.flatten = nn.Flatten()
		self.linear_relu_stack = nn.Sequential(
			nn.Linear(28 * 28, 512),
			nn.ReLU(),
			nn.Linear(512, 512),
			nn.ReLU(),
			nn.Linear(512, 10))

	def forward(self, X):
		X = self.flatten(X)
		logist = self.linear_relu_stack(X)
		return logist


# 判断设备上是否有cuda的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# new_model = My_Model().to(device=device)
# 实例化神经网络,并把神经网络放到GPU上
new_model = My_Model().to(device)

print(f"device={device}")

# 定义输入矩阵x,将参数x放入到GPU上
x = torch.rand(3, 28, 28, device=device)
# 将输入矩阵进入网络后输出
y = new_model(x)
# 得到summary
summary(new_model, input_data=x, device=device)
pred_probab = nn.Softmax(dim=1)(y)
y_pred = pred_probab.argmax(1)
print(f"Predicted class:{y_pred}")
  • 结果:
device=cuda
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Flatten: 1-1                           [-1, 784]                 --
├─Sequential: 1-2                        [-1, 10]                  --
|    └─Linear: 2-1                       [-1, 512]                 401,920
|    └─ReLU: 2-2                         [-1, 512]                 --
|    └─Linear: 2-3                       [-1, 512]                 262,656
|    └─ReLU: 2-4                         [-1, 512]                 --
|    └─Linear: 2-5                       [-1, 10]                  5,130
==========================================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
Total mult-adds (M): 1.34
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.01
Params size (MB): 2.55
Estimated Total Size (MB): 2.57
==========================================================================================
Predicted class:tensor([8, 8, 0], device='cuda:0')
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-11 22:11:31  更:2022-03-11 22:15:44 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 14:50:02-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码