IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> windows11下运行swin-transformer算法 -> 正文阅读

[人工智能]windows11下运行swin-transformer算法

一、背景

我们希望使用swin-transformer算法实现物体的分类。?

swin-transformer的github地址为:https://github.com/microsoft/Swin-Transformer

本文参考:Swin Transformer实战:使用 Swin Transformer实现图像分类-阿里云开发者社区

二、环境配置

(1)配置要求

windows版本:windows11

pytorch版本:1.7.1

python版本:3.7.3(至少要大于3.6.2,因为pytorch1.7.1的python最低要求是3.6.2)

cuda版本:11.0(pytorch1.7.1在windows11下使用,最少需要cuda11.0)

以上配置为我试验swin-transformer运行的相对比较低的配置要求。

(2)安装方式

torch1.7.1的安装命令:pip install torch==1.7.1 -f https://download.pytorch.org/whl/torch_stable.html

torchvision的安装命令:pip install torchvision==0.8.2 -f https://download.pytorch.org/whl/torch_stable.html

cuda中官网下载toolkit即可,windows下cuda11.0.X版本可选择windows10对应的版本。

三、训练集构造

swin-transformer默认读取imagenet格式的数据集。

数据集的目录结构如下:

四、修改源码

1、修改config.py文件

_C.DATA.DATA_PATH = 'D:\\temp\\pic_ai\\swin_transformer_test'

_C.MODEL.NUM_CLASSES = 2

_C.DATA.NUM_WORKERS = 0

_C.DATA.PIN_MEMORY = False? ?

2、修改build.py文件

将nb_classes =1000改为nb_classes = config.MODEL.NUM_CLASSES,如下所示:

将部分_pil_interp修改为str_to_pil_interp,如下图所示:

?3、修改utils.py文件

由于类别默认是1000,所以加载模型的时候会出现类别对不上的问题,所以需要修改load_checkpoint方法。在加载预训练模型之前增加修改预训练模型的方法:

if checkpoint['model']['head.weight'].shape[0] == 1000:
    checkpoint['model']['head.weight'] = torch.nn.Parameter(
        torch.nn.init.xavier_uniform(torch.empty(config.MODEL.NUM_CLASSES, 768)))
    checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(config.MODELNUM_CLASSES))
msg = model.load_state_dict(checkpoint['model'], strict=False)

?

?4、修改main.py文件

(1)将如下代码注释:

(2)将torch.distributed.init_process_group修改为:

torch.distributed.init_process_group('gloo', init_method='file://tmp/somefile', rank=0, world_size=1)

该函数只有在pytorch1.7.1以上才支持。

5、修改lr_scheduler.py文件

将如下代码注释掉

五、运行训练命令:

python.exe D:/workspace/transformer/Swin-Transformer/main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --local_rank 0 --batch-size 2

执行后显示如下:

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-03 10:48:33  更:2022-07-03 10:51:49 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 0:55:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码