IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> DeamNet||训练代码学习train.py注释与解析 -> 正文阅读

[Python知识库]DeamNet||训练代码学习train.py注释与解析

目录

1. 导入各种库,设置运行环境

2. 训练设置,各种参数

3.测试


?1. 导入各种库,设置运行环境

from __future__ import print_function
import os
import time
import socket
import pandas as pd
import argparse
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from DeamNet import Deam
from data import get_training_set, get_eval_set
from skimage.measure.simple_metrics import compare_psnr
from real_dataloader import *

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# 修改:增加此行.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 判断CUDA是否能使用,不可以就使用CPU

2. 训练设置,各种参数

parser.add_argument() 用法
argparse 模块是 Python 内置的一个用于命令行选项与参数解析的模块
使用argparse 的第一步是创建一个 ArgumentParser 对象

name or flags - 选项字符串的名字或者列表

default - 不指定参数时的默认值。
type - 命令行参数应该被转换成的类型。

help - 参数的帮助信息,当指定为 argparse.SUPPRESS 时表示不显示该参数的帮助信息.

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--upscale_factor', type=int, default=1, help="super resolution upscale factor")
parser.add_argument('--batchSize', type=int, default=1, help='training batch size')   # 修改
parser.add_argument('--nEpochs', type=int, default=10, help='number of epochs to train for')  # 修改 nEpochs=2000
parser.add_argument('--start_iter', type=int, default=1, help='starting epoch')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate. default=0.0001')
parser.add_argument('--data_augmentation', type=bool, default=True, help='if adopt augmentation when training')
parser.add_argument('--hr_train_dataset', type=str, default='DIV2K_train_HR', help='the training dataset')
parser.add_argument('--Ispretrained', type=bool, default=True, help='If load checkpoint model')
parser.add_argument('--pretrained_sr', default='noise25.pth', help='sr pretrained base model')
parser.add_argument('--pretrained', default='./Deam_models', help='Location to load checkpoint models')
parser.add_argument("--noiseL", type=float, default=25, help='noise level')
parser.add_argument('--save_folder', default='./checkpoint/', help='Location to save checkpoint models')
parser.add_argument('--statistics', default='./statistics/', help='Location to save statistics')

# Testing settings
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size, default=1')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--test_dataset', type=str, default='Set12', help='the testing dataset')
parser.add_argument("--val_noiseL", type=float, default=25, help='noise level used on validation set')

# Global settings
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--gpus', default=1, type=int, help='number of gpus')
parser.add_argument('--data_dir', type=str, default='D:/Papers to read/2022.07/DeamNet-main/DeamNet-main/Dataset', help='the dataset dir')
parser.add_argument('--model_type', type=str, default='Deam', help='the name of model')
parser.add_argument('--patch_size', type=int, default=128, help='Size of cropped HR image')
parser.add_argument('--Isreal', default=False, help='If training/testing on RGB images')

# ArgumentParser 通过 parse_args() 方法解析参数
opt = parser.parse_args()
gpus_list = range(opt.gpus)
hostname = str(socket.gethostname()) # Socket 获取本地主机名 socket.gethostname()函数
# str函数是Python的内置函数,它将参数转换成字符串类型
cudnn.benchmark = True
print(opt)

cudnn.benchmark = True

设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。
如果网络的输入数据维度或类型上变化不大,也就是每次训练的图像尺寸都是一样的时候,设置 torch.backends.cudnn.benchmark = True 可以增加运行效率;
如果网络的输入数据在每次 iteration 都变化的话,会导致 cnDNN 每次都会去寻找一遍最优配置,这样反而会降低运行效率?

enumerate函数作用于一个可遍历对象(如列表,元组或者字符串),将其组合为一个索引序列,可以同时获得索引和值(常使用在for循环中)?

enumerate(sequence,[start = 0])

?sequence:可遍历对象

start:下标的起始位置,可以指定遍历的起始位置

Variable(变量)??在 Torch 中的 Variable 就是一个存放会变化的值的地理位置。里面的值会不停的变化。

torch.normal?(means, std, out=None) 返回一个张量,包含从给定参数means,std的离散正态分布中抽取随机数。

用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()optimizer.step()三个函数

def train(epoch):
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        target = Variable(batch)
        # torch.FloatTensor 类型转换, 将list ,numpy转化为tensor 
        noise = torch.FloatTensor(target.size()).normal_(mean=0, std=opt.val_noiseL / 255.)
        input = target + noise

        input = input.cuda()
        target = target.cuda()

        model.zero_grad()
        optimizer.zero_grad()   # 先将梯度归零
        t0 = time.time()

        prediction = model(input)

        # Corresponds to the Optimized Scheme
        loss = criterion(prediction, target)/(input.size()[0]*2)

        t1 = time.time()
        epoch_loss += loss.data
        loss.backward()     # 然后反向传播计算得到每个参数的梯度值
        optimizer.step()    # 最后通过梯度下降执行一步参数更新

        if (iteration+1) % 50 == 0:  #判断迭代次数取余等于0
            model.eval()    #评估模式而非训练模式。等同于 self.train(False)
            SC = 'net_epoch_' + str(epoch) + '_' + str(iteration + 1) + '.pth'
             # model.state_dict() 浅拷贝:拷贝最外层的数值和指针,不拷贝更深层次的对象
            torch.save(model.state_dict(), os.path.join(opt.save_folder, SC)) 
              # os.path.join 路径拼接
            model.train()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Timer: {:.4f} sec.".format(epoch, iteration, len(training_data_loader), loss.data, (t1 - t0)))
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))

def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i, :, :, :], Img[i, :, :, :], data_range=data_range)
    return (PSNR / Img.shape[0])

?3.测试

def test(testing_data_loader):
    psnr_test= 0
    model.eval()
    for batch in testing_data_loader:
        target = Variable(batch[0])
        noise = torch.FloatTensor(target.size()).normal_(mean=0, std=opt.noiseL / 255.)
        input = target + noise

        input = input.cuda()
        target = target.cuda()
        with torch.no_grad():
            prediction = model(input)
# torch.clamp(input, min, max, out=None).将input张量的值压缩到区间 [min,max],结果返回到一个新张量。
            prediction = torch.clamp(prediction, 0., 1.)
        psnr_test += batch_PSNR(prediction, target, 1.)
    print("===> Avg. PSNR: {:.4f} dB".format(psnr_test / len(testing_data_loader)))
    return psnr_test / len(testing_data_loader)


def print_network(net):
    num_params = 0
    for param in net.parameters():
     # param.numel() 统计模型参数量
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


def checkpoint(epoch,psnr):
    model_out_path = opt.save_folder+hostname+opt.model_type+"_psnr_{}".format(psnr)+"_epoch_{}.pth".format(epoch)
    torch.save(model.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

if __name__ == '__main__':
    print('===> Loading datasets')

    if opt.Isreal:
        train_set = Dataset_h5_real(src_path=os.path.join(opt.data_dir, 'train', 'train.h5'), patch_size=opt.patch_size, train=True)
        training_data_loader = DataLoader(dataset=train_set, batch_size=opt.batch_size, shuffle=True, num_workers=4,
                                drop_last=True)
        test_set = Dataset_h5_real(src_path=os.path.join(opt.data_dir, 'test', 'val.h5'), patch_size=opt.patch_size, train=False)
        testing_data_loader = DataLoader(dataset=test_set, batch_size=opt.testBatchSize, shuffle=False, num_workers=0, drop_last=True)
    else:
        train_set = get_training_set(os.path.join(opt.data_dir, 'train'), opt.hr_train_dataset, opt.upscale_factor,
                                     opt.patch_size, opt.data_augmentation)
        training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

        test_set = get_eval_set(os.path.join(opt.data_dir, 'test', opt.test_dataset), opt.upscale_factor)
        testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

    print('===> Building model ', opt.model_type)
    model = Deam(opt.Isreal)

    model = torch.nn.DataParallel(model, device_ids=gpus_list)
    criterion = nn.MSELoss()

    print('---------- Networks architecture -------------')
    print_network(model)
    print('----------------------------------------------')

    if opt.Ispretrained:
        model_name = os.path.join(opt.pretrained, opt.pretrained_sr)
    # load_state_dict 模型加载
        model.load_state_dict(torch.load(model_name, map_location=lambda storage, loc: storage))
        print(model_name + ' model is loaded.')

  # optim.Adam() 实现Adam算法
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8)

    PSNR = []
    # 修改
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()

    for epoch in range(opt.start_iter, opt.nEpochs + 1):
        train(epoch)
        psnr = test(testing_data_loader)
        PSNR.append(psnr)
        data_frame = pd.DataFrame(
            data={'epoch': epoch, 'PSNR': PSNR}, index=range(1, epoch+1)
        )
        data_frame.to_csv(os.path.join(opt.statistics, 'training_logs.csv'), index_label='index')
        # learning rate is decayed by a factor of 10 every half of total epochs


        if (epoch + 1) % (opt.nEpochs / 2) == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10.0
            # print('Learning rate decay: lr={}'.param_group['lr'])
            print('Learning rate decay: lr={}')


torch.optim.Adam 方法的使用和参数的解释

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-09-15 01:58:19  更:2022-09-15 01:58:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/26 13:45:12-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计