基于 PyTorch 的 Grad-CAM 可视化开源库（pytorch-grad-cam）：https://gitee.com/hejuncheng1/pytorch-grad-cam
安装命令
pip install grad-cam
具体用法可参考 CSDN 博客文章《Swin Transformer各层特征可视化》（作者：不高兴与没头脑Fire）。
示例代码如下：第一部分为数据加载模块 dataloader.py，第二部分为 Grad-CAM 可视化脚本（通过 `import dataloader` 调用第一部分）。
from torchvision import datasets, transforms
import os
import torch
# Square side length (pixels) every split is resized to; change it via update().
input_size = 224

# Per-channel mean/std passed to Normalize for every split.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _build_transforms(size):
    """Return the per-split preprocessing pipelines for square images of *size* pixels.

    'train' resizes, then applies random resized crop, small random translation
    and random horizontal/vertical flips before tensor conversion and
    normalization. 'val' and 'test' only resize, convert and normalize.
    """
    def _eval_pipeline():
        # Built fresh per call so 'val' and 'test' get independent Compose objects.
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ])

    return {
        'train': transforms.Compose([
            transforms.Resize((size, size)),
            transforms.RandomResizedCrop(size=size, scale=(0.7, 1)),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ]),
        'val': _eval_pipeline(),
        'test': _eval_pipeline(),
    }


# Split name -> transform pipeline; the keys double as the split folder names
# looked up by dataloader().
data_transforms = _build_transforms(input_size)


def update(new_input_size):
    """Switch every split's pipeline to square images of *new_input_size* pixels.

    Updates the module-level ``input_size`` and rebuilds ``data_transforms``.
    The dict is refreshed in place, so code that imported it by name
    (``from dataloader import data_transforms``) also sees the new pipelines.
    """
    global input_size
    input_size = new_input_size
    data_transforms.clear()
    data_transforms.update(_build_transforms(new_input_size))
def dataloader(data_dir, batch_size, set_name, shuffle):
    """Create a DataLoader over the ImageFolder stored at ``data_dir/set_name``.

    Args:
        data_dir: root directory holding one sub-folder per split.
        batch_size: number of samples per batch.
        set_name: split name; selects both the sub-folder and the pipeline in
            the module-level ``data_transforms`` (KeyError if it has no entry).
        shuffle: forwarded to DataLoader.

    Returns:
        A pair ``({set_name: loader}, number_of_images_in_the_split)``.
    """
    split_dataset = datasets.ImageFolder(os.path.join(data_dir, set_name), data_transforms[set_name])

    # A single worker process when CUDA is available, otherwise load in the main process.
    workers = 0
    if torch.cuda.is_available():
        workers = 1

    loader = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
    )
    return {set_name: loader}, len(split_dataset)
if __name__ == '__main__':
    # Smoke test: build the shuffled training loader and show it with the split size.
    root = ''  # placeholder: set to the dataset root that contains a 'train' sub-folder
    loaders, train_size = dataloader(data_dir=root, batch_size=16, set_name='train', shuffle=True)
    print(loaders, train_size)
import cv2
import numpy as np
import torch
import torch.nn as nn
import os
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import dataloader
def reshape_transform(tensor, height=12, width=12):
    """Lay a token sequence out as a channels-first feature map for Grad-CAM.

    Args:
        tensor: activations of shape (batch, height * width, channels); tokens
            are assumed to be in row-major grid order.
        height: number of grid rows (default 12).
        width: number of grid columns (default 12).

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    batch, channels = tensor.size(0), tensor.size(2)
    grid = tensor.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)
if __name__ == '__main__':
    net_name = 'swin_base_patch4_window12_384_22k'
    categories_size = 2  # number of classes of the fine-tuned classifier head

    if net_name == 'swin_base_patch4_window12_384_22k':
        from models import swintf
        model_ft = swintf.build_model('config/swin_base_patch4_window12_384_22k.yaml', use_checkpoint=True)
        # Replace the classifier head; 1024 is the feature width this script assumes for this backbone.
        model_ft.head = nn.Linear(1024, categories_size)
        # Rebuild the shared transforms for 384x384 inputs.
        dataloader.update(384)
    else:
        raise ValueError(f'unsupported net_name: {net_name!r}')

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    model_ft = model_ft.to(device)

    load_path = os.path.join('./save', net_name + '.pth')
    if os.path.exists(load_path):
        msg = model_ft.load_state_dict(torch.load(load_path, map_location=device))
        print('msg:', msg)
    else:
        # Without a checkpoint the new head is randomly initialised; make that visible.
        print(f'warning: checkpoint {load_path} not found; the classifier head is untrained')
    model_ft.eval()

    target_layer = [model_ft.norm]
    target_category = 0

    image_path = ''  # placeholder: path of the image to visualise
    # Force 3 channels: grayscale/RGBA/palette images would break the 3-channel
    # Normalize in the transform and the heatmap overlay below.
    image = Image.open(image_path).convert('RGB')
    transformer = dataloader.data_transforms['test']
    inputs = transformer(image).unsqueeze(0)

    # use_cuda must follow where the model lives; a GPU model fed a CPU tensor fails.
    cam = GradCAM(model=model_ft, target_layers=target_layer, use_cuda=use_gpu,
                  reshape_transform=reshape_transform)
    cam.batch_size = 1
    grayscale_cam = cam(input_tensor=inputs, target_category=target_category, eigen_smooth=True,
                        aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    # Match the overlay to the size the model actually saw, not a hard-coded value.
    side = dataloader.input_size
    rgb_image = np.array(image.resize((side, side))) / 255.0
    # show_cam_on_image (default use_rgb=False) blends a BGR heatmap, and cv2.imwrite
    # expects BGR, so hand it the image in BGR order to keep the colours consistent.
    bgr_image = np.ascontiguousarray(rgb_image[:, :, ::-1])
    cam_image = show_cam_on_image(bgr_image, grayscale_cam)
    if not cv2.imwrite('cam.jpg', cam_image):
        raise RuntimeError('failed to write cam.jpg')
    print('OK')