IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> resMLP的PyTorch代码和分析 -> 正文阅读

[人工智能]resMLP的PyTorch代码和分析

?1、导入相关功能包

import torch
from torch import nn
from torchsummary import summary
from tensorboardX import SummaryWriter

2、定义Affine模块

Aff_{\alpha ,\beta }(x)=Diag(\alpha )x+\beta

????????初始化α=1, β=0。

class Affine(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, 1, channel))
        self.beta = nn.Parameter(torch.zeros(1, 1, channel))

    def forward(self, x):
        return x * self.alpha + self.beta

3、定义PreAffinePostLayerScale模块

????????参考:Going deeper with Image Transformers

class PreAffinePostLayerScale(nn.Module):  # https://arxiv.org/abs/2103.17239
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.affine = Affine(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.affine(x)) * self.scale + x

????????就是实现了以下部分:

?4、构建resMLP网络

class ResMLP(nn.Module):
    def __init__(self, dim=128, image_size=256, patch_size=16, expansion_factor=4, depth=2, class_num=1000):
        super().__init__()
        self.flatten = Rearange(image_size, patch_size)  # Rearange(image_size=256, patch_size=16)
        num_patches = (image_size // patch_size) ** 2
        wrapper = lambda i, fn: PreAffinePostLayerScale(dim, i + 1, fn) # 封装
        self.embedding = nn.Linear((patch_size ** 2) * 3, dim)
        self.mlp = nn.Sequential()
        for i in range(depth):
            self.mlp.add_module('fc1_%d' % i, wrapper(i, nn.Conv1d(patch_size ** 2, patch_size ** 2, 1)))
            # nn.Conv1d(patch_size ** 2 = 256, patch_size ** 2 = 256, 1)
            self.mlp.add_module('fc1_%d' % i, wrapper(i, nn.Sequential(
                nn.Linear(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Linear(dim * expansion_factor, dim)
            )))

        self.aff = Affine(dim)
        self.classifier = nn.Linear(dim, class_num)
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        y = self.flatten(x)
        y = self.embedding(y)
        # a = y.shape
        y = self.mlp(y)
        y = self.aff(y)
        y = torch.mean(y, dim=1)  # bs,dim
        out = self.softmax(self.classifier(y))
        return out

?????????在第一层学习到的交互模式非常类似于一个小型卷积过滤器:

self.mlp.add_module('fc1_%d' % i, wrapper(i, nn.Conv1d(patch_size ** 2, patch_size ** 2, 1)))

网络结构如下:

5、测试网络

# 测试resMLP
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # model = gMLPForImageClassification(image_size=256, patch_size=16, in_channels=3, num_classes=1000, d_model=256,
    #                                    d_ffn=512, seq_len=256, num_layers=1, ).to(device)
    model = ResMLP(dim=128, image_size=256, patch_size=16, class_num=1000).to(device)
    summary(model, (3, 256, 256))  # [2,3,256,256]

    inputs = torch.Tensor(2, 3, 256, 256)
    inputs = inputs.to(device)
    print(inputs.shape)
    # 将model保存为graph
    with SummaryWriter(log_dir='logs', comment='model') as w:
        w.add_graph(model, (inputs,))
        print("success")

????????得到输出如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
          Rearange-1             [-1, 256, 768]               0
            Linear-2             [-1, 256, 128]          98,432
            Affine-3             [-1, 256, 128]               0
            Linear-4             [-1, 256, 512]          66,048
              GELU-5             [-1, 256, 512]               0
            Linear-6             [-1, 256, 128]          65,664
PreAffinePostLayerScale-7             [-1, 256, 128]               0
            Affine-8             [-1, 256, 128]               0
            Linear-9                 [-1, 1000]         129,000
          Softmax-10                 [-1, 1000]               0
================================================================
Total params: 359,144
Trainable params: 359,144
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 4.77
Params size (MB): 1.37
Estimated Total Size (MB): 6.89
----------------------------------------------------------------
torch.Size([2, 3, 256, 256])
success

????????通过TensorboardX查看网络具体结构:

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-19 11:52:51  更:2021-10-19 11:53:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 10:50:08-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码