2021SC@SDUSC
上次我们分析了model类中的类定义函数,了解了这个类各个参数的意思,接下来我们继续分析train.py。
m = m.to(args.device)
m就是一个model类,这是将m加载到设备device上。
if args.ckpt:
cpt = torch.load(args.ckpt)
m.load_state_dict(cpt)
starte = int(args.ckpt.split("/")[-1].split(".")[0])+1
args.lr = float(args.ckpt.split("-")[-1])
print('ckpt restored')
args是调用了pargs()函数的变量,是一个ArgumentParser对象,包含将命令行解析成 Python 数据类型所需的全部信息,在pargs.py中有体现。?
args = parser.parse_args()
parser.add_argument("-ckpt",default=None,type=str,help='load checkpoint')
对args进行检查point,若为真则加载到m的字典中,starte将checkpoint以“/”和“.”分割后数量+1。
parser.add_argument("-lr",default=0.1,type=float,help='learning rate')
这是有关args.lr的代码解释,在pargs.py文件中。
else:
with open(args.save+"/commandLineArgs.txt",'w') as f:
f.write("\n".join(sys.argv[1:]))
starte=0
若不为真,则新建一个文件,路径为保存后的文件路径+commandLineArgs.txt,向其中写入sys.argv[1:],也就是一个列表,里面都是用户输入的参数,然后将starte设为0。
o = torch.optim.SGD(m.parameters(),lr=args.lr, momentum=0.9)
我们先简单介绍一下torch.optim.SGD函数,它是实现随机梯度下降算法的函数,在普通的梯度下降法x+=v中,每次x的更新量v为v=dx*lr,其中dx为目标函数func(x)的对x的一阶导数。当使用“冲量”时,则把每次x的更新量v考虑为本次的梯度下降量-dx*lr与上次x的更新量v乘上一个介于[0,1]因子的momentum的和,即v=-dx*lr+v*momentum。这条语句是把这个模型m的参数,包括权重和偏置(是神经网络中的参数,也是SGD优化的重点)。lr为学习率,冲量设为0.9。
for e in range(starte,args.epochs):
print("epoch ",e,"lr",o.param_groups[0]['lr'])
train(m,o,ds,args)
starte为之前设置过的一个变量,已经被赋值,args.epochs默认为20,在pargs.py文件中。
parser.add_argument("-epochs",default=20,type=int)
在starte和args.epochs的范围里进行循环操作,每次循环之前打印第几次epoch和学习率,然后对m,o,ds,args调用train()函数。train()为train.py中的函数,也是这个测试的核心代码,我将在后面的博客里详细分析。
vloss = evaluate(m,ds,args)
evaluate()为评估函数,也是需要详细分析的核心代码。(该语句在上面的for循环中)
if args.lrwarm:
update_lr(o,args,e)
print("Saving model")
parser.add_argument("-lrwarm",action="store_true",help='use cycling learning rate')
关于parser.add_argument()记录一个特殊的情况:action。当运行的时候,如果不加入--lr_use,那么程序运行的时候,lr_use的值为default:False,如果加上了--lr_use,不需要指定True/False,那么程序运行的时候,lr_use的值为True。
|