2021SC@SDUSC
我们继续看main(args)的那个循环。
for e in range(starte,args.epochs):
print("epoch ",e,"lr",o.param_groups[0]['lr'])
train(m,o,ds,args)
vloss = evaluate(m,ds,args)
if args.lrwarm:
update_lr(o,args,e)
print("Saving model")
torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
if vloss > lastloss:
if args.lrdecay:
print("decay lr")
o.param_groups[0]['lr'] *= 0.5
lastloss = vloss
上次我们分析了train函数,接下来我们分析evaluate函数。这个函数主要对数据集进行评估操作。
def evaluate(m,ds,args):
print("Evaluating",end="\t")
m.eval()
loss = 0
ex = 0
m、ds、args参数和train的函数参数相同,此处不再重复。然后对m变字符串。
for b in ds.val_iter:
b = ds.fixBatch(b)
p,z,planlogits = m(b)
p = p[:,:-1,:]
tgt = b.tgt[:,1:].contiguous().view(-1).to(args.device)
l = F.nll_loss(p.contiguous().view(-1,p.size(2)),tgt,ignore_index=1)
每个循环都对数据集进行fixBatch函数的调用,进行数据集的批处理,返回的data赋给b。然后对b进行模型化,.contiguous()和.nll_loss()函数在之前的博客已经详细分析过。
if ex == 0:
g = p[0].max(1)[1]
print(ds.reverse(g,b.rawent[0]))
loss += l.item() * len(b.tgt)
ex += len(b.tgt)
loss = loss/ex
print("VAL LOSS: ",loss,end="\t")
if loss < 100: print(" PPL: ",exp(loss))
m.train()
return loss
for循环结束以后,对m调用train()函数,返回损失值。
回到main函数,继续往下看循环体内的操作。
if args.lrwarm:
update_lr(o,args,e)
print("Saving model")
parser.add_argument("-lrwarm",action="store_true",help='use cycling learning rate')
这里对参数做一个判断,如果使用循环学习率,就调用update_lr函数,更新学习率。
def update_lr(o,args,epoch):
if epoch%args.lrstep == 0:
o.param_groups[0]['lr'] = args.lrhigh
else:
o.param_groups[0]['lr'] -= args.lrchange
parser.add_argument("-epochs",default=20,type=int)
parser.add_argument("-lrstep",default=4, type=int,help='steps in cycle')
parser.add_argument("-lrhigh",default=0.5,type=float,help="high learning rate for cycling")
args.lrchange = (args.lrhigh - args.lr)/args.lrstep
epoch的默认值为20,lrstep是学习率在循环中的次数。lrhigh是高学习率。这里在训练中动态调整学习率,o.param_groups是长度为6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]这6个参数。
torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
if vloss > lastloss:
if args.lrdecay:
print("decay lr")
o.param_groups[0]['lr'] *= 0.5
lastloss = vloss
parser.add_argument("-lrdecay",action="store_true",help="use learning rate decay")
后面保存学习率。如果用衰减学习率,就将学习率*0.5。最终的损失数等于vloss参数。
至此,train.py的核心代码和关键代码已经全部分析完毕,接下来我将和小组成员进行交流,并进行其他关键代码的分析。
|