While reading recent dehazing papers, I studied DW-GAN, the winning entry of the CVPR 2021 dehazing challenge track. The paper introduces a new dehazing network, DW-GAN, built around the 2D discrete wavelet transform. It uses a two-branch design to handle complex haze distributions and overfitting: the DWT branch applies the wavelet transform, the knowledge-adaptation branch builds on Res2Net, and a patch-based discriminator is used to reduce artifacts in the restored images.
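For intuition, the 2D Haar DWT that the DWT branch relies on can be written with plain tensor slicing. The helper below is only an illustrative sketch of the transform (it is not taken from the official repository): each call halves the spatial resolution and concatenates the four sub-bands along the channel axis.

import torch

def haar_dwt(x):
    # x: (N, C, H, W) tensor; returns (N, 4C, H/2, W/2) with the four
    # Haar sub-bands (LL, HL, LH, HH) stacked along the channel axis.
    x01 = x[:, :, 0::2, :] / 2   # even rows
    x02 = x[:, :, 1::2, :] / 2   # odd rows
    x1 = x01[:, :, :, 0::2]      # even rows, even columns
    x2 = x02[:, :, :, 0::2]      # odd rows, even columns
    x3 = x01[:, :, :, 1::2]      # even rows, odd columns
    x4 = x02[:, :, :, 1::2]      # odd rows, odd columns
    LL = x1 + x2 + x3 + x4       # low-frequency approximation
    HL = -x1 - x2 + x3 + x4      # high-frequency detail
    LH = -x1 + x2 - x3 + x4      # high-frequency detail
    HH = x1 - x2 - x3 + x4       # high-frequency detail
    return torch.cat((LL, HL, LH, HH), dim=1)

Each application of the transform trades spatial resolution for channels, which is how the DWT branch can downsample while still carrying the high-frequency texture information that ordinary pooling tends to discard.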
The source code released by the authors only contains test code, so below is train.py, the training script as I understand it.
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from test_dataset import dehaze_test_dataset
from model import fusion_net, Discriminator
from torchvision.utils import save_image as imwrite
import os
import time
import re
from train_dataset import dehaze_train_dataset
from torchvision.models import vgg16
from utils_test import to_psnr, to_ssim_skimage
import torch.nn.functional as F
from perceptual import LossNetwork
from pytorch_msssim import msssim
# --- Parse hyper-parameters train --- #
parser = argparse.ArgumentParser(description='Dehaze Training')
parser.add_argument('--train_dir', type=str, default='./train/')
parser.add_argument('--output_dir', type=str, default='./trained_result/')
parser.add_argument('-train_batch_size', help='Set the training batch size', default=2, type=int)
parser.add_argument('-learning_rate', type=float, default=1e-4)
parser.add_argument('-train_epoch', help='Set the training epoch', default=500, type=int)
parser.add_argument('--model_save_dir', type=str, default='./output_result')
# --- Parse hyper-parameters test --- #
parser.add_argument('--test_dir', type=str, default='./test_image/')
parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int)
parser.add_argument('--vgg_model', default='', type=str, help='load trained model or not')
args = parser.parse_args()
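# A typical way to launch this script (the values shown are simply the defaults defined above):
#   python train.py --train_dir ./train/ --test_dir ./test_image/ -train_batch_size 2 -train_epoch 500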
# --- train --- #
learning_rate = args.learning_rate
train_batch_size = args.train_batch_size
train_epoch = args.train_epoch
train_dir = args.train_dir
output_dir = args.output_dir
train_dataset = dehaze_train_dataset(train_dir)
# --- test --- #
test_dir = args.test_dir
test_dataset = dehaze_test_dataset(test_dir)
test_batch_size = args.test_batch_size
# --- output picture and check point --- #
if not os.path.exists(args.model_save_dir):
    os.makedirs(args.model_save_dir)
# --- Gpu device --- #
# move the model and the data onto the GPU
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# --- Define the network --- #
net = fusion_net()
DNet = Discriminator()
print('# Discriminator parameters:', sum(param.numel() for param in DNet.parameters()))
# --- Multi-GPU --- #
net = net.to(device)
net = nn.DataParallel(net, device_ids=device_ids)
DNet = DNet.to(device)
DNet = nn.DataParallel(DNet, device_ids=device_ids)
#net.load_state_dict(torch.load('./weights/dehaze.pkl'))
# --- Build optimizer --- #
G_optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(G_optimizer, milestones=[5000, 7000, 8000], gamma=0.5)
D_optim = torch.optim.Adam(DNet.parameters(), lr=learning_rate)
scheduler_D = torch.optim.lr_scheduler.MultiStepLR(D_optim, milestones=[5000, 7000, 8000], gamma=0.5)
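# Note: with the default train_epoch of 500 and one scheduler step per epoch, the
# milestones at 5000/7000/8000 are never reached, so the learning rate stays at its
# initial value unless train_epoch is increased or the milestones are lowered.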
# --- Load training data --- #
train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
# --- Load testing data --- #
test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)
# --- Define the perceptual loss network --- #
# a pretrained VGG16 serves as the loss network for measuring perceptual similarity
vgg_model = vgg16(pretrained=True)
vgg_model = vgg_model.features[:16].to(device)
for param in vgg_model.parameters():
    param.requires_grad = False
loss_network = LossNetwork(vgg_model)
loss_network.eval()
msssim_loss = msssim
# --- Load the network weight --- #
try:
    # optionally resume from a previously saved checkpoint
    net.load_state_dict(torch.load(os.path.join(args.model_save_dir, 'epoch100000.pkl')))
    print('--- weight loaded ---')
except:
    print('--- no weight loaded ---')
# --- Start training --- #
iteration = 0
for epoch in range(train_epoch):
    start_time = time.time()
    net.train()
    DNet.train()
    print('epoch:', epoch)
    for batch_idx, (name, hazy, clean) in enumerate(train_loader):
        iteration += 1
        hazy = hazy.to(device)
        clean = clean.to(device)
        frame_out = net(hazy)
        # --- Discriminator step: score real images high, dehazed outputs low --- #
        DNet.zero_grad()
        real_out = DNet(clean).mean()
        fake_out = DNet(frame_out).mean()
        D_loss = 1 - real_out + fake_out
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
        D_loss.backward(retain_graph=True)
        # --- Generator step: smooth L1 + perceptual + adversarial + MS-SSIM --- #
        net.zero_grad()
        adversarial_loss = torch.mean(1 - fake_out)
        smooth_loss_l1 = F.smooth_l1_loss(frame_out, clean)
        perceptual_loss = loss_network(frame_out, clean)
        msssim_loss_ = -msssim_loss(frame_out, clean, normalize=True)
        total_loss = smooth_loss_l1 + 0.01 * perceptual_loss + 0.0005 * adversarial_loss + 0.5 * msssim_loss_
        total_loss.backward()
        D_optim.step()
        G_optimizer.step()
    # --- Evaluate on the test set every 5 epochs and save a checkpoint --- #
    if epoch % 5 == 0:
        print('we are testing on epoch: ' + str(epoch))
        with torch.no_grad():
            psnr_list = []
            ssim_list = []
            net.eval()
            for batch_idx, (name, hazy, clean) in enumerate(test_loader):
                clean = clean.to(device)
                hazy = hazy.to(device)
                frame_out = net(hazy)
                if not os.path.exists(output_dir + '/'):
                    os.makedirs(output_dir + '/')
                name = re.findall(r"\d+", str(name))
                # imwrite(frame_out, output_dir + '/' + str(name[0]) + '.png', range=(0, 1))  # save the dehazed image
                psnr_list.extend(to_psnr(frame_out, clean))
                ssim_list.extend(to_ssim_skimage(frame_out, clean))
            avr_psnr = sum(psnr_list) / len(psnr_list)
            avr_ssim = sum(ssim_list) / len(ssim_list)
            print(epoch, 'dehazed', avr_psnr, avr_ssim)
            frame_debug = torch.cat((frame_out, clean), dim=0)  # last test batch stacked with its ground truth, for visual debugging
        torch.save(net.state_dict(), os.path.join(args.model_save_dir, 'epoch' + str(epoch) + '.pkl'))
    # --- Step the LR schedulers once per epoch, after the optimizer updates --- #
    scheduler_G.step()
    scheduler_D.step()
    print('epoch %d finished in %.1f s' % (epoch, time.time() - start_time))
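The listing above also imports LossNetwork from perceptual and to_psnr / to_ssim_skimage from utils_test, none of which are shown here. As a rough sketch of the perceptual part, assuming LossNetwork simply compares VGG16 feature maps of the output and the ground truth at a few intermediate layers (the layer indices 3/8/15 and the MSE comparison are my assumptions, not taken from the official code):

import torch.nn as nn
import torch.nn.functional as F

class LossNetwork(nn.Module):
    # Minimal perceptual loss: compare VGG16 feature maps of the restored
    # image and the ground truth at a few intermediate ReLU layers.
    def __init__(self, vgg_model, layers=(3, 8, 15)):
        super().__init__()
        self.vgg = vgg_model          # the frozen vgg16.features[:16] built above
        self.layers = set(layers)     # assumed feature-extraction points

    def extract(self, x):
        feats = []
        for idx, module in enumerate(self.vgg):
            x = module(x)
            if idx in self.layers:
                feats.append(x)
        return feats

    def forward(self, output, target):
        loss = 0.0
        for f_out, f_gt in zip(self.extract(output), self.extract(target)):
            loss = loss + F.mse_loss(f_out, f_gt)
        return loss / len(self.layers)

Because the VGG weights are frozen in train.py (requires_grad=False), gradients from this loss flow only into the generator through its output image.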