IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PyTorch Swin-Transformer 各层特征可视化 -> 正文阅读

[人工智能]PyTorch Swin-Transformer 各层特征可视化

PyTorch相关开源库
https://gitee.com/hejuncheng1/pytorch-grad-cam

安装命令

pip install grad-cam

具体使用参考
Swin Transformer各层特征可视化_不高兴与没头脑Fire的博客-CSDN博客

提供示例

# dataloader.py
from torchvision import datasets, transforms
import os
import torch

input_size = 224

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.RandomResizedCrop(size=input_size, scale=(0.7, 1)),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    'test': transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
}


def update(new_input_size):
    global input_size
    global data_transforms

    input_size = new_input_size

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.RandomResizedCrop(size=input_size, scale=(0.7, 1)),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
        'test': transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    }


def dataloader(data_dir, batch_size, set_name, shuffle):
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in [set_name]}
    num_workers = 1 if torch.cuda.is_available() else 0
    dataset_loaders = {
        x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers)
        for x in [set_name]}
    dataset_sizes = len(image_datasets[set_name])
    return dataset_loaders, dataset_sizes


if __name__ == '__main__':
    data_dir = ''
    dset_loaders, dset_sizes = dataloader(data_dir=data_dir, batch_size=16, set_name='train', shuffle=True)
    print(dset_loaders, dset_sizes)
# main.py
import cv2
import numpy as np
import torch
import torch.nn as nn
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

import dataloader


def reshape_transform(tensor, height=12, width=12):
    result = tensor.reshape(tensor.size(0),
                            height, width, tensor.size(2))
    result = result.transpose(2, 3).transpose(1, 2)
    return result


if __name__ == '__main__':
    net_name = 'swin_base_patch4_window12_384_22k'
    categories_size = 2
    model_ft = None

    if net_name == 'swin_base_patch4_window12_384_22k':
        from models import swintf

        model_ft = swintf.build_model('config/swin_base_patch4_window12_384_22k.yaml', use_checkpoint=True)
        model_ft.head = nn.Linear(1024, categories_size)
        dataloader.update(384)

    use_gpu = True if torch.cuda.is_available() else False
    if use_gpu:
        model_ft = model_ft.cuda()

    load_path = os.path.join('./save', net_name + '.pth')
    if os.path.exists(load_path):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        msg = model_ft.load_state_dict(torch.load(load_path, map_location=device))
    print('msg:', msg)

    model_ft.eval()

    target_layer = [model_ft.norm]

    target_category = 0
    image_path = ''
    image = Image.open(image_path)
    transformer = dataloader.data_transforms['test']
    image_ = transformer(image)
    inputs = image_.unsqueeze(0)

    cam = GradCAM(model=model_ft, target_layers=target_layer, use_cuda=False, reshape_transform=reshape_transform)
    cam.batch_size = 1
    grayscale_cam = cam(input_tensor=inputs, target_category=target_category, eigen_smooth=True,
                        aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    image = np.array(image.resize((384, 384))) / 255.0
    cam_image = show_cam_on_image(image, grayscale_cam)
    cv2.imwrite('cam.jpg', cam_image)
    print('OK')
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-09 18:22:35  更:2022-04-09 18:23:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 10:32:37-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码