IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 知识图到文本的生成——拾 -> 正文阅读

[人工智能]知识图到文本的生成——拾

2021SC@SDUSC

我们继续看main(args)的那个循环。

  for e in range(starte,args.epochs):
    print("epoch ",e,"lr",o.param_groups[0]['lr'])
    train(m,o,ds,args)
    vloss = evaluate(m,ds,args)
    if args.lrwarm:
      update_lr(o,args,e)
    print("Saving model")
    torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
    if vloss > lastloss:
      if args.lrdecay:
        print("decay lr")
        o.param_groups[0]['lr'] *= 0.5
    lastloss = vloss

上次我们分析了train函数,接下来我们分析evaluate函数。这个函数主要对数据集进行评估操作。

def evaluate(m,ds,args):
  print("Evaluating",end="\t")
  m.eval()
  loss = 0
  ex = 0

m、ds、args参数和train的函数参数相同,此处不再重复。然后对m变字符串。

  for b in ds.val_iter:
    b = ds.fixBatch(b)
    p,z,planlogits = m(b)
    p = p[:,:-1,:]
    tgt = b.tgt[:,1:].contiguous().view(-1).to(args.device)
    l = F.nll_loss(p.contiguous().view(-1,p.size(2)),tgt,ignore_index=1)

每个循环都对数据集进行fixBatch函数的调用,进行数据集的批处理,返回的data赋给b。然后对b进行模型化,.contiguous()和.nll_loss()函数在之前的博客已经详细分析过。

    if ex == 0:
      g = p[0].max(1)[1]
      print(ds.reverse(g,b.rawent[0]))
    loss += l.item() * len(b.tgt)
    ex += len(b.tgt)
  loss = loss/ex
  print("VAL LOSS: ",loss,end="\t")
  if loss < 100: print(" PPL: ",exp(loss))
  m.train()
  return loss

for循环结束以后,对m调用train()函数,返回损失值。

回到main函数,继续往下看循环体内的操作。

    if args.lrwarm:
      update_lr(o,args,e)
    print("Saving model")
parser.add_argument("-lrwarm",action="store_true",help='use cycling learning rate')

这里对参数做一个判断,如果使用循环学习率,就调用update_lr函数,更新学习率。

def update_lr(o,args,epoch):
  if epoch%args.lrstep == 0:
    o.param_groups[0]['lr'] = args.lrhigh
  else:
    o.param_groups[0]['lr'] -= args.lrchange
parser.add_argument("-epochs",default=20,type=int)
parser.add_argument("-lrstep",default=4, type=int,help='steps in cycle')
parser.add_argument("-lrhigh",default=0.5,type=float,help="high learning rate for cycling")
args.lrchange = (args.lrhigh - args.lr)/args.lrstep

epoch的默认值为20,lrstep是学习率在循环中的次数。lrhigh是高学习率。这里在训练中动态调整学习率,o.param_groups是长度为6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]这6个参数。

    torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
    if vloss > lastloss:
      if args.lrdecay:
        print("decay lr")
        o.param_groups[0]['lr'] *= 0.5
    lastloss = vloss
parser.add_argument("-lrdecay",action="store_true",help="use learning rate decay")

后面保存学习率。如果用衰减学习率,就将学习率*0.5。最终的损失数等于vloss参数。

至此,train.py的核心代码和关键代码已经全部分析完毕,接下来我将和小组成员进行交流,并进行其他关键代码的分析。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-26 22:09:34  更:2021-12-26 22:14:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 20:34:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码