Spatial Transformer Networks(STN)-代码实现
-
PyTorch 为了方便实现 STN,封装了 `F.affine_grid` 和 `F.grid_sample` 两个高级 API(分别对应下文的 Grid generator 和 Sampler)。
STN 的基本步骤是:
- $\color{blue}{Localisation\;net}$(参数预测):Localisation net 模块通过 CNN 提取图像的特征,回归出仿射变换矩阵 $\theta$(一个 $2\times 3$ 的矩阵)。
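- $\color{green}{Grid\;generator}$(坐标映射):Grid generator 模块利用 Localisation net 回归出来的 $\theta$ 参数,对图片中的位置进行变换,建立输入图片与输出图片之间的坐标映射:对输出图片中每个像素的坐标 $(x_i^t, y_i^t)$,计算它在输入图片中对应的采样坐标 $\begin{pmatrix}x_i^s\\ y_i^s\end{pmatrix} = \theta \begin{pmatrix}x_i^t\\ y_i^t\\ 1\end{pmatrix}$。需要特别注意的是,这里变换的是图片像素所对应的位置(坐标),而不是像素值本身。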
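- $\color{gray}{Sampler}$(像素采样):Sampler 用来解决 Grid generator 变换后出现小数坐标的问题。针对这种情况,STN 采用双线性插值(Bilinear Interpolation):对每个小数坐标 $(x_i^s, y_i^s)$,用其周围 4 个整数像素按距离加权得到输出像素值,即 $V_i = \sum_{n}^{H}\sum_{m}^{W} U_{nm}\,\max(0, 1-|x_i^s-m|)\,\max(0, 1-|y_i^s-n|)$,其中 $U_{nm}$ 是输入图片在 $(m, n)$ 处的像素值。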
1. STN层的实现
from torchvision import transforms
import torch.nn.functional as F
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# Demo: shift a single image with a fixed affine matrix using PyTorch's STN
# building blocks (affine_grid = grid generator, grid_sample = sampler).
img = Image.open("img/test.jpg")
img_tensor = transforms.ToTensor()(img)
batch = img_tensor.unsqueeze(0)  # add a leading batch dimension

# 2x3 affine matrix: identity linear part plus a translation term of 0.1
# along x and 0.2 along y (in affine_grid's normalised [-1, 1] coordinates).
theta = torch.tensor([[1, 0, 0.1], [0, 1, 0.2]], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), batch.size(), align_corners=True)
output = F.grid_sample(batch, grid, align_corners=True)

# Show the original and the warped image side by side; the tensor is
# converted from (C, H, W) to (H, W, C) for matplotlib.
panels = [
    (np.array(img), "original image"),
    (output[0].numpy().transpose(1, 2, 0), "stn transform image"),
]
plt.figure()
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)
    plt.title(caption)
plt.show()
2. STN+CNN
输入图片经过 STN 模块得到变换后的图片,再将变换后的图片输入到 CNN 网络中,通过损失函数计算 $loss$,然后反向传播计算梯度并更新参数。注意 $\theta$ 本身不是可直接学习的参数,而是由 Localisation net 预测得到的,因此梯度实际更新的是 Localisation net 的权重;最终 STN 模块会学习到如何矫正图片。
2.1 参数设置
config.py
import argparse
def _str2bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so boolean options could never be
    switched off from the command line. This converter fixes that.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def parse_args(argv=None):
    """Parse the STN training options.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with lr, epoch_nums, use_stn, batch_size, use_eval,
        use_visual, use_gpu and show_net_construct.
    """
    parse = argparse.ArgumentParser("config stn args")
    parse.add_argument("--lr", default=0.01,
                       type=float, help="learning rate")
    parse.add_argument("--epoch_nums", default=20,
                       type=int, help="iterated epochs")
    parse.add_argument("--use_stn", default=True,
                       type=_str2bool, help="whether to use STN module")
    parse.add_argument("--batch_size", default=64,
                       type=int, help="batch size")
    parse.add_argument("--use_eval", default=True,
                       type=_str2bool, help="whether to evaluate")
    parse.add_argument("--use_visual", default=True,
                       type=_str2bool, help="visual STN transform image")
    parse.add_argument("--use_gpu", default=True,
                       type=_str2bool, help="whether to use GPU")
    parse.add_argument("--show_net_construct", default=False,
                       type=_str2bool, help="print net construct info")
    return parse.parse_args(argv)
2.2 加载数据
DataLoader.py
import torch
from torchvision import datasets,transforms
import numpy as np
def get_dataloader(batch_size, root=r"D:\PyCharm\PyCharm_Project\STN"):
    """Build the MNIST train and test DataLoaders.

    Args:
        batch_size: number of samples per batch for both loaders.
        root: directory where torchvision stores/looks for MNIST. Defaults to
            the original hard-coded path (now a raw string, so the
            backslashes are not treated as escape sequences).

    Returns:
        (train_dataloader, test_dataloader). Both shuffle their data; only the
        training split is downloaded automatically if missing.
    """
    # Same preprocessing for both splits: to tensor, then normalise with the
    # MNIST mean/std used throughout this tutorial.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root=root, train=True, download=True,
                       transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(root=root, train=False, transform=mnist_transform),
        batch_size=batch_size, shuffle=True)
    return train_dataloader, test_dataloader
def tensor_to_array(img_tensor, mean=(0.1307,), std=(0.3081,)):
    """Convert a normalised (C, H, W) image tensor into an (H, W, C) array for display.

    Undoes ``transforms.Normalize(mean, std)`` and clips the result to [0, 1].
    The defaults are the MNIST statistics used by ``get_dataloader``; the
    previous ImageNet statistics did not match how the images were normalised.

    Args:
        img_tensor: image tensor laid out channel-first; it may live on any
            device or require grad.
        mean: per-channel mean used for normalisation. A length-1 sequence
            broadcasts across every channel.
        std: per-channel standard deviation used for normalisation.

    Returns:
        numpy array of shape (H, W, C) with values in [0, 1].
    """
    img_array = img_tensor.detach().cpu().numpy().transpose((1, 2, 0))
    # mean/std broadcast over the trailing (channel) axis.
    img_array = np.asarray(std) * img_array + np.asarray(mean)
    return np.clip(img_array, 0, 1)
2.3 定义网络
Net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class STN_Net(nn.Module):
    """MNIST classifier with an optional Spatial Transformer in front of it.

    The layer sizes (10 * 3 * 3 localisation features, 320 = 20 * 4 * 4
    classifier features) work out for 1 x 28 x 28 inputs.

    Args:
        use_stn: if True, ``forward`` first warps the input with ``stn``
            before classifying it.
    """

    def __init__(self, use_stn=True):
        super().__init__()
        # Classification CNN.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self._use_stn = use_stn

        # Localisation net: for a 28x28 input the feature map is
        # 28 -> 22 (conv7) -> 11 (pool) -> 7 (conv5) -> 3 (pool), i.e. 10x3x3.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )
        # Regressor for the 2x3 affine matrix theta.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 2 * 3),
        )
        # Zero weights + identity bias: the STN starts as the identity
        # transform, regardless of the input.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Predict theta from ``x`` and resample ``x`` with it.

        Returns a tensor with the same size as ``x``.
        """
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        # align_corners=False is PyTorch's default; passing it explicitly keeps
        # that behaviour and avoids the warning emitted when it is omitted.
        # Both calls must use the same setting.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        if self._use_stn:
            x = self.stn(x)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # log_softmax pairs with F.nll_loss used for training/evaluation.
        return F.log_softmax(x, dim=1)
2.4 训练模型
train.py
import torch,torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def train(net, epoch_nums, lr, train_dataloader, per_batch, device):
    """Train ``net`` in place with plain SGD and the negative log-likelihood loss.

    Args:
        net: model returning log-probabilities.
        epoch_nums: number of passes over ``train_dataloader``.
        lr: SGD learning rate.
        train_dataloader: iterable of (inputs, targets) batches.
        per_batch: print a progress line every ``per_batch`` batches.
        device: device the batches are moved to before the forward pass.
    """
    net.train()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    for epoch in range(epoch_nums):
        for step, (inputs, targets) in enumerate(train_dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = F.nll_loss(net(inputs), targets)
            loss.backward()
            optimizer.step()

            if step % per_batch != 0:
                continue
            seen = step * len(inputs)
            progress = 100. * step / len(train_dataloader)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_dataloader.dataset), progress,
                loss.item()))
2.5 评估模型
evaluate.py
import torch
import torch.nn.functional as F
def evaluate(net, test_dataloader, device):
    """Evaluate ``net`` on ``test_dataloader`` and print average loss and accuracy.

    Switches ``net`` to eval mode (it is not switched back).

    Args:
        net: model returning log-probabilities.
        test_dataloader: iterable of (inputs, targets) batches with a
            ``.dataset`` attribute supporting ``len``.
        device: device the batches are moved to.

    Returns:
        (eval_loss, eval_acc): mean per-sample NLL loss and the number of
        correctly classified samples.
    """
    net.eval()
    eval_loss = 0.0
    eval_acc = 0
    with torch.no_grad():
        for data, label in test_dataloader:
            data, label = data.to(device), label.to(device)
            pred = net(data)
            # reduction="sum" replaces the deprecated size_average=False;
            # the total is divided by the dataset size below.
            eval_loss += F.nll_loss(pred, label, reduction="sum").item()
            pred_label = pred.argmax(dim=1, keepdim=True)
            eval_acc += pred_label.eq(label.view_as(pred_label)).sum().item()
    dataset_size = len(test_dataloader.dataset)
    eval_loss /= dataset_size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(eval_loss, eval_acc, dataset_size,
                  100. * eval_acc / dataset_size))
    return eval_loss, eval_acc
2.6 可视化
Visualize.py
import torch,torchvision
import matplotlib.pyplot as plt
from DataLoader import tensor_to_array
def visualize_stn(net, dataloader, device):
    """Plot one batch from ``dataloader`` next to the same batch after ``net.stn``.

    Args:
        net: model exposing an ``stn`` method.
        dataloader: iterable of (images, labels) batches; only the first
            batch's images are used.
        device: device the images are moved to before running ``net.stn``.
    """
    with torch.no_grad():
        batch_images = next(iter(dataloader))[0].to(device)
        originals = batch_images.cpu()
        warped = net.stn(batch_images).cpu()

        original_grid = tensor_to_array(torchvision.utils.make_grid(originals))
        warped_grid = tensor_to_array(torchvision.utils.make_grid(warped))

        fig, axes = plt.subplots(1, 2)
        panels = ((original_grid, "input images"),
                  (warped_grid, "stn transformed images"))
        for ax, (image, caption) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(caption)
        plt.show()
2.7 主函数
MAIN.py
import torch
from Net import STN_Net
from Visualize import visualize_stn
from train import train
from config import parse_args
from DataLoader import get_dataloader
from evaluate import evaluate
if __name__ == "__main__":
    args = parse_args()
    # Fall back to the CPU when a GPU was requested but CUDA is unavailable.
    device = "cuda" if args.use_gpu and torch.cuda.is_available() else "cpu"

    train_loader, test_loader = get_dataloader(args.batch_size)
    net = STN_Net(args.use_stn).to(device)
    # --show_net_construct was parsed but previously never honoured.
    if args.show_net_construct:
        print(net)

    # NOTE(review): the 5th argument is train()'s logging interval in batches;
    # the batch size is reused for it here.
    train(net, args.epoch_nums, args.lr, train_loader,
          args.batch_size, device)
    if args.use_eval:
        evaluate(net, test_loader, device)
    # When use_stn is off, net.stn is not part of forward(), so it never
    # trains and stays at its identity initialisation; visualising it would
    # just show unchanged inputs.
    if args.use_visual and args.use_stn:
        visualize_stn(net, test_loader, device)
Test set: Average loss: 0.0423, Accuracy: 9868/10000 (99%)
参考
- 通俗易懂的Spatial Transformer Networks(STN)(一)
- 通俗易懂的Spatial Transformer Networks(STN)(二)
- SPATIAL TRANSFORMER NETWORKS TUTORIAL
|