Data
# Build the train/test datasets, then wrap them in DataLoaders
data_loader = get_loader(args.dataset)
data_path = get_data_path(args.dataset)
tr_loader = data_loader(data_path, args.exp, is_transform=True, split='train', img_size=(args.img_rows, args.img_cols), augmentations=data_trans_train)
te_loader = data_loader(data_path, args.exp, is_transform=True, split='test', img_size=(args.img_rows, args.img_cols), augmentations=data_trans_test)
trainloader = DataLoader(tr_loader, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)
testloader = DataLoader(te_loader, batch_size=args.batch_size, shuffle=True, pin_memory=True)
Hyperparameters
beta = 1e-2
gamma = 1e-1
mix_ratio = 0.5
feature_channel = 6
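feature_channel = 6 matches the extractor's 6-channel output, which is later split into 3-channel semantic and illumination halves. How beta, gamma and mix_ratio weight the four losses is not shown in this snippet; a minimal sketch, assuming beta scales reconstruction and gamma the feature-level terms:

def total_loss(ce, re, ms, il):
    # Hypothetical weighting; the actual combination used in train() is not shown here.
    return ce + beta * re + gamma * (ms + il)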
main function
if __name__ == "__main__":
    out_root = Path(outimg_path)
    if not out_root.is_dir():
        os.makedirs(out_root, exist_ok=True)
    best_acc = 0
    for e in range(1, args.epochs + 1):
        train(e)
        best_acc = lalala(e, best_acc)  # evaluate and track the best accuracy so far
        print('========epoch(%d)=========' % e)
        print('best_acc:%.2f' % best_acc)
Training
tr_class = tr_loader.tr_class
te_class = te_loader.te_class
Loss functions
The paper uses four loss functions:
def loss_match_func(feat_sem, temp_sem):
    """
    Mean squared error between the two semantic feature maps.
    :param feat_sem: torch.Size([16, 3, 64, 64])
    :param temp_sem: torch.Size([16, 3, 64, 64])
    :return: scalar
    """
    MS = match_loss(feat_sem, temp_sem)
    return MS
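match_loss itself is defined elsewhere; given the docstring ("mean squared error"), a minimal sketch would simply be:

import torch.nn as nn

# Assumption: match_loss is plain MSE, as the docstring suggests.
match_loss = nn.MSELoss()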
recon loss: guarantees that the semantic features carry enough information to reconstruct the template
def loss_recon_func(recon_feat_sem, recon_temp_sem, template, recon_temp_sup=None, template_sup=None):
    """
    Reconstruction loss.
    :param recon_feat_sem: torch.Size([16, 3, 64, 64]) reconstruction from the image's semantic features
    :param recon_temp_sem: torch.Size([16, 3, 64, 64]) reconstruction from the template's semantic features
    :param template: torch.Size([16, 3, 64, 64]) template
    :param recon_temp_sup: optional reconstruction for the supervised templates
    :param template_sup: optional supervised templates
    :return: scalar
    """
    RE = recon_loss(recon_feat_sem, template) + recon_loss(recon_temp_sem, template)
    if recon_temp_sup is not None:
        recon_sup = recon_loss(recon_temp_sup, template_sup)
        RE += recon_sup
    return RE
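recon_loss is likewise defined elsewhere. Since decode (below) ends in a sigmoid, a pixel-wise criterion such as MSE is a plausible choice; a sketch under that assumption:

# Assumption: recon_loss compares reconstructions to templates pixel-wise.
recon_loss = nn.MSELoss()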
def loss_class_func(out, target, out_sup=None, target_sup=None):
    """
    Cross-entropy loss.
    :param out: torch.Size([16, 11]) classifier logits
    :param target: torch.Size([16]) class labels
    :param out_sup: optional logits for the supervised branch
    :param target_sup: optional labels for the supervised branch
    :return: scalar
    """
    CE = F.cross_entropy(out, target)
    CE_sup = 0
    if out_sup is not None:
        CE_sup = F.cross_entropy(out_sup, target_sup)
    return CE + CE_sup
def loss_illu_func(feat_illu, target):
    """
    For images that share a label, the more dissimilar their illumination
    features the better, so the within-class PIDA scatter is negated.
    :param feat_illu: torch.Size([16, 3, 64, 64])
    :param target: torch.Size([16])
    :return: scalar
    """
    pida_illu = PIDA_loss(feat_illu, target)
    return -pida_illu
def PIDA_loss(feature, target):
    """
    Within-class scatter: for each class in the batch, the distance of its
    features to the per-class mean feature.
    :param feature: torch.Size([16, 3, 64, 64])
    :param target: torch.Size([16])
    :return: scalar
    """
    tg_unique = torch.unique(target)
    pida_loss = 0
    for tg in tg_unique:
        feature_split = feature[target == tg, :, :, :]
        mean_feature = torch.mean(feature_split, 0).unsqueeze(0)
        mean_feature_rep = mean_feature.repeat(feature_split.shape[0], 1, 1, 1)
        pida_loss += match_loss(feature_split, mean_feature_rep)
    return pida_loss
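A quick sanity check of what PIDA_loss measures: for each class present in the batch it accumulates the distance of that class's features to the class mean, i.e. within-class scatter. Since loss_illu_func returns its negative, minimizing the total loss pushes same-class illumination features apart. Assuming match_loss is MSE as sketched above:

import torch

feature = torch.randn(16, 3, 64, 64)   # batch of feature maps
target = torch.randint(0, 4, (16,))    # 4 classes in the batch
print(PIDA_loss(feature, target))      # > 0: features deviate from their class means

# identical features within a class give zero scatter
print(PIDA_loss(torch.ones(16, 3, 64, 64), target))  # tensor(0.)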
Model
extract
- Extracts the image's semantic features and illumination features
- Consists of 6 convolutional layers without pooling layers
- The final 6-channel output is split into two 3-channel maps: one is the semantic feature, the other the illumination feature
def extract(self, x, is_warping):
    if is_warping and self.param1 is not None:
        x = self.stn1(x)
    h1 = self.leakyrelu(self.ex_bn1(self.ex1(self.ex_pd1(x))))
    h2 = self.leakyrelu(self.ex_bn2(self.ex2(self.ex_pd2(h1))))
    if is_warping and self.param2 is not None:
        h2 = self.stn2(h2)
    h3 = self.leakyrelu(self.ex_bn3(self.ex3(self.ex_pd3(h2))))
    h4 = self.leakyrelu(self.ex_bn4(self.ex4(self.ex_pd4(h3))))
    if is_warping and self.param3 is not None:
        h4 = self.stn3(h4)
    h5 = self.leakyrelu(self.ex_bn5(self.ex5(self.ex_pd5(h4))))
    h6 = self.sigmoid(self.ex_bn6(self.ex6(self.ex_pd6(h5))))
    # split the 6-channel output into semantic / illumination halves
    feat_sem, feat_illu = torch.chunk(h6, 2, 1)
    feat_sem_nowarp = feat_sem
    if is_warping and self.param4 is not None:
        feat_sem = self.stn4(feat_sem)
    return feat_sem, feat_illu, feat_sem_nowarp
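Shape-wise the extractor preserves spatial resolution (padded convolutions, no pooling) and ends with feature_channel = 6 channels; torch.chunk then splits them along dim 1 into two 3-channel maps:

import torch

h6 = torch.rand(16, 6, 64, 64)               # last extractor activation
feat_sem, feat_illu = torch.chunk(h6, 2, 1)  # channels 0-2 vs. 3-5
print(feat_sem.shape, feat_illu.shape)       # torch.Size([16, 3, 64, 64]) twice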
An stn (spatial transformer network) is used here to rectify deformed semantic features:
class stn(nn.Module):
    def __init__(self, input_channels, input_size, params):
        super(stn, self).__init__()
        self.input_size = input_size
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(input_channels, params[0], kernel_size=5, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(params[0], params[1], kernel_size=5, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv3 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(params[1], params[2], kernel_size=3, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        out_numel, out_size = convNoutput([self.conv1, self.conv2, self.conv3], input_size / 2)
        self.fc = nn.Sequential(
            View(),
            nn.Linear(out_numel, params[3]),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            View(),
            nn.Linear(params[3], 6)
        )
        # initialize the regression head to output the identity affine transform
        self.classifier[1].weight.data.fill_(0)
        self.classifier[1].bias.data = torch.FloatTensor([1, 0, 0, 0, 1, 0])

    def localization_network(self, x):
        x = self.maxpool(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def forward(self, x):
        # predict a 2x3 affine matrix, then warp the input with it
        theta = self.localization_network(x)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x
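The bias [1, 0, 0, 0, 1, 0] together with the zeroed weights makes the localization network initially output the flattened 2×3 identity matrix, so the STN starts as a no-op warp and only learns to deviate from it during training. A small check of that identity behaviour:

import torch
import torch.nn.functional as F

x = torch.rand(16, 3, 64, 64)
theta = torch.tensor([1., 0., 0., 0., 1., 0.]).repeat(16, 1).view(-1, 2, 3)
grid = F.affine_grid(theta, x.size(), align_corners=False)
out = F.grid_sample(x, grid, align_corners=False)
print(torch.allclose(out, x, atol=1e-6))  # True: the identity warp returns the input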
decode
Decodes the semantic features back into the template.
def decode(self, x):
    h1 = self.leakyrelu(self.de_bn1(self.de1(self.de_pd1(x))))
    h2 = self.leakyrelu(self.de_bn2(self.de2(self.de_pd2(h1))))
    h3 = self.leakyrelu(self.de_bn3(self.de3(self.de_pd3(h2))))
    h4 = self.leakyrelu(self.de_bn4(self.de4(self.de_pd4(h3))))
    out = self.sigmoid(self.de5(self.de_pd5(h4)))
    return out
classify
- Why two classifiers? Fully connected layers are used, and the train split has 43 classes while the test split has only 11. An FC layer's output dimension is fixed, so two separate heads are needed.
- Why not train the classifier directly on the disentangled semantic features? The paper discusses this point.
This classifier covers the 43 training classes:
def classify(self, x):
    h1 = self.pool2(self.leakyrelu(self.cls_bn1(self.cls1(x))))
    h2 = self.leakyrelu(self.cls_bn2(self.cls2(h1)))
    h3 = self.pool2(self.leakyrelu(self.cls_bn3(self.cls3(h2))))
    h4 = self.leakyrelu(self.cls_bn4(self.cls4(h3)))
    h5 = self.pool2(self.leakyrelu(self.cls_bn5(self.cls5(h4))))
    h6 = self.leakyrelu(self.cls_bn6(self.cls6(h5)))
    h7 = h6.view(-1, int(self.input_size / 8 * self.input_size / 8 * self.classify_chn[5]))
    out = self.fc1(h7)
    return out
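The flattened size in the view comes from the three pool2 layers, each halving the spatial resolution, so the input reaches fc1 as (input_size/8) × (input_size/8) × classify_chn[5] features. For instance (classify_chn[5] = 128 is a hypothetical value, not taken from this snippet):

input_size = 64
chn_last = 128  # hypothetical stand-in for self.classify_chn[5]
print(int(input_size / 8 * input_size / 8 * chn_last))  # 8192 -> fc1 in_features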
classify2
This classifier covers the 11 test-split classes:
def classify2(self, x):
    h1 = self.pool2(self.leakyrelu(self.cls2_bn1(self.cls21(x))))
    h2 = self.leakyrelu(self.cls2_bn2(self.cls22(h1)))
    h3 = self.pool2(self.leakyrelu(self.cls2_bn3(self.cls23(h2))))
    h4 = self.leakyrelu(self.cls2_bn4(self.cls24(h3)))
    h5 = self.pool2(self.leakyrelu(self.cls2_bn5(self.cls25(h4))))
    h6 = self.leakyrelu(self.cls2_bn6(self.cls26(h5)))
    h7 = h6.view(-1, int(self.input_size / 8 * self.input_size / 8 * self.classify_chn[5]))
    out = self.fc2(h7)
    return out
init_params
def init_params(self, net):
    print('Loading the model from the file...')
    net_dict = self.state_dict()
    if isinstance(net, dict):
        pre_dict = net               # a state-dict was passed in directly
    else:
        pre_dict = net.state_dict()  # a model was passed in
    # keep only the keys that also exist in this network
    pre_dict = {k: v for k, v in pre_dict.items() if (k in net_dict)}
    net_dict.update(pre_dict)
    self.load_state_dict(net_dict)
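Because only matching keys are copied, init_params can warm-start the network from a checkpoint whose architecture only partially overlaps. A usage sketch (the checkpoint path is hypothetical):

import torch

state = torch.load('checkpoints/model_best.pth', map_location='cpu')  # hypothetical path
model.init_params(state)  # accepts a state-dict or a whole model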
Testing
Testing is largely the same as training, except classification uses classify2.