IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> DW-GAN训练代码 -> 正文阅读

[人工智能]DW-GAN训练代码

? ? ? ? 最近在看去雾方面的论文时,学习了2021年CVPR去雾赛道冠军DWT,这篇论文引入了一种使用二维离散小波变换的新型去雾网络DW-GAN,使用双分支网络来解决雾度分布复杂和过拟合问题,在 DWT 分支中利用小波变换,在knowledge adaptation分支中使用Res2Net,最后使用基于补丁的判别器来减少恢复图像的伪影。

? ? ? ? 作者提供的源码中只有测试代码,,以下给出我自己理解的训练代码train.py

import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from test_dataset import dehaze_test_dataset
from model import fusion_net,Discriminator
from torchvision.utils import save_image as imwrite
import os
import time
import re
from train_dataset import dehaze_train_dataset
from torchvision.models import vgg16
from utils_test import to_psnr, to_ssim_skimage
import torch.nn.functional as F
from perceptual import LossNetwork
from pytorch_msssim import msssim

# --- Parse hyper-parameters train --- #
parser = argparse.ArgumentParser(description='Dehaze Training')
parser.add_argument('--train_dir', type=str, default='./train/')
parser.add_argument('--output_dir', type=str, default='./trained_result/')
parser.add_argument('-train_batch_size', help='Set the training batch size', default=2, type=int)
parser.add_argument('-learning_rate', type=float, default=1e-4)
parser.add_argument('-train_epoch', help='Set the training epoch', default=500, type=int)
parser.add_argument('--model_save_dir', type=str, default='./output_result')

# --- Parse hyper-parameters test --- #
parser.add_argument('--test_dir', type=str, default='./test_image/')
parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int)
parser.add_argument('--vgg_model', default='', type=str, help='load trained model or not')
args = parser.parse_args()

# --- train --- #
learning_rate = args.learning_rate
train_batch_size = args.train_batch_size
train_epoch = args.train_epoch
train_dir = args.train_dir
output_dir = args.output_dir
train_dataset = dehaze_train_dataset(train_dir)

# --- test --- #
test_dir = args.test_dir
test_dataset = dehaze_test_dataset(test_dir)
test_batch_size = args.test_batch_size

# --- output picture and check point --- #
if not os.path.exists(args.model_save_dir):
    os.makedirs(args.model_save_dir)

# --- Gpu device --- #
#将模型和数据迁移到gpu上
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
net = fusion_net()
DNet = Discriminator()
print('# Discriminator parameters:', sum(param.numel() for param in DNet.parameters()))

# --- Multi-GPU --- #
net = net.to(device)
net = nn.DataParallel(net)
DNet = DNet.to(device)
DNet= nn.DataParallel(DNet)

#net.load_state_dict(torch.load('./weights/dehaze.pkl'))

# --- Build optimizer --- #
G_optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(G_optimizer, milestones=[5000, 7000, 8000], gamma=0.5)
D_optim = torch.optim.Adam(DNet.parameters(), lr=0.0001)
scheduler_D = torch.optim.lr_scheduler.MultiStepLR(D_optim, milestones=[5000,7000,8000], gamma=0.5)

# --- Load training data --- #
train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)

# --- Load testing data --- #
test_loader = DataLoader(dataset=train_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)


# --- Define the perceptual loss network --- #
#预训练的VGG16作为损失网络来测量感知相似度
vgg_model = vgg16(pretrained=True)
vgg_model = vgg_model.features[:16].to(device)
for param in vgg_model.parameters():
    param.requires_grad = False
loss_network = LossNetwork(vgg_model)
loss_network.eval()

msssim_loss = msssim

# --- Load the network weight --- #
try:
    net.load_state_dict(torch.load(os.path.join(args.teacher_model, 'epoch100000.pkl')))
    print('--- weight loaded ---')
except:
    print('--- no weight loaded ---')

# --- Strat training --- #
iteration = 0
for epoch in range(train_epoch):
    start_time = time.time()
    scheduler_G.step()
    scheduler_D.step()
    net.train()
    DNet.train()
    print(epoch)
    for batch_idx, (name, hazy, clean) in enumerate(train_loader):
        iteration += 1
        hazy = hazy.to(device)
        frame_out = net(hazy)
        clean = clean.to(device)

        DNet.zero_grad()
        real_out = DNet(clean).mean()
        fake_out = DNet(frame_out).mean()
        D_loss = 1 - real_out + fake_out

        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()

        D_loss.backward(retain_graph=True)
        net.zero_grad()
        adversarial_loss = torch.mean(1 - fake_out)
        smooth_loss_l1 = F.smooth_l1_loss(frame_out, clean)
        perceptual_loss = loss_network(frame_out, clean)
        msssim_loss_ = -msssim_loss(frame_out, clean, normalize=True)
        total_loss = smooth_loss_l1 + 0.01 * perceptual_loss + 0.0005 * adversarial_loss + 0.5 * msssim_loss_

        total_loss.backward()
        D_optim.step()
        G_optimizer.step()

    if epoch % 5 == 0:
        print('we are testing on epoch: ' + str(epoch))
        with torch.no_grad():
            psnr_list = []
            ssim_list = []
            recon_psnr_list = []
            recon_ssim_list = []
            net.eval()
            for batch_idx, (name, hazy, clean) in enumerate(test_loader):
                clean = clean.to(device)
                hazy = hazy.to(device)
                frame_out = net(hazy)
                if not os.path.exists(output_dir + '/'):
                    os.makedirs(output_dir + '/')
                name = re.findall("\d+", str(name))
                #imwrite(frame_out, output_dir + '/' + str(name[0]) + '.png', range=(0, 1))  # 保存图像
                psnr_list.extend(to_psnr(frame_out, clean))
                ssim_list.extend(to_ssim_skimage(frame_out, clean))

            avr_psnr = sum(psnr_list) / len(psnr_list)
            avr_ssim = sum(ssim_list) / len(ssim_list)
            print(epoch, 'dehazed', avr_psnr, avr_ssim)
            frame_debug = torch.cat((frame_out, clean), dim=0)
            torch.save(net.state_dict(), os.path.join(args.model_save_dir, 'epoch' + str(epoch) + '.pkl'))
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-08-19 19:04:57  更:2022-08-19 19:07:06 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/25 23:44:49-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码