IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 小黑深入底层:batchify_with_label数据处理 -> 正文阅读

[人工智能]小黑深入底层:batchify_with_label数据处理

import torch
import torch.autograd as autograd
def batchify_with_label(input_batch_list,gpu,num_layer,volatile_flag = False):
    """
    input_list:batch_size个7*list
        0.wordid:(max_len)
        1.biword_id:(max_len])
        2.----
        3.gaz_id:max个[[id1...],[len1...]]
        4.tag_id:[max_len]
        5.layer_gazs:(max_len,num_layer)
        6.gaz_mask:(max_len,num_layer)
    
    """
    # 遍历每个batch_size
    batch_size = len(input_batch_list)
    words = [sent[0] for sent in input_batch_list]
    biwords = [sent[1] for sent in input_batch_list]
    gazs = [sent[3] for sent in input_batch_list]
    labels = [sent[4] for sent in input_batch_list]
    layer_gazs = [sent[5] for sent in input_batch_list]
    gaz_mask = [sent[6] for sent in input_batch_list]
    # 获得最大长度
    word_seq_lengths = torch.LongTensor(list(map(len,words)))
    max_seq_len = word_seq_lengths.max()
    
    word_seq_tensor = autograd.Variable(torch.zeros((batch_size,max_seq_len))).long()
    biword_seq_tensor = autograd.Variable(torch.zeros((batch_size,max_seq_len))).long()
    label_seq_tensor = autograd.Variable(torch.zeros((batch_size,max_seq_len))).long()
    layer_gaz_tensor = torch.zeros(batch_size,max_seq_len,num_layer).long()
    mask = autograd.Variable(torch.zeros((batch_size,max_seq_len))).byte()
    gaz_mask_tensor = torch.zeros((batch_size,max_seq_len,num_layer)).byte()
    
    for idx,(seq,biseq,label,seq_len,layergaz,gazmask) in enumerate(
            zip(words,biwords,labels,word_seq_lengths,layer_gazs,gaz_mask)):
        word_seq_tensor[idx,:seq_len] = torch.LongTensor(seq)
        biword_seq_tensor[idx,:seq_len] = torch.LongTensor(biseq)
        label_seq_tensor[idx,:seq_len] = torch.LongTensor(label)
        layer_gaz_tensor[idx,:seq_len] = torch.LongTensor(layergaz)
        mask[idx,:seq_len] = torch.Tensor([1] * int(seq_len))
        gaz_mask_tensor[idx,:seq_len] = torch.LongTensor(gazmask)
    if gpu:
        word_seq_tensor = word_seq_tensor.cuda()
        biword_seq_tensor = biword_seq_tensor.cuda()
        word_seq_lengths = word_seq_lengths.cuda()
        label_seq_tensor = label_seq_tensor.cuda()
        layer_gaz_tensor = layer_gaz_tensor.cuda()
        gaz_mask_tensor = gaz_mask_tensor.cuda()
        mask = mask.cuda()
    return gazs,word_seq_tensor,biword_seq_tensor,word_seq_lengths,label_seq_tensor,layer_gaz_tensor,gaz_mask_tensor,mask
input_batch_list = [[
    [1,3,2,1,2,3],
    [2,3,4,2,2,6],
    [],
    [[[2,1],[3,3]],[],[[[2],[3]]],[],[],[]],
    [2,2,2,1,4,2],
    [[3,0,0,0],[0,6,0,0],[0,0,0,0],[9,0,0,0],[0,0,3,0],[0,0,0,0]],
    [[0,1,1,1],[1,0,1,1],[1,1,1,1],[0,1,1,1],[1,1,0,1],[1,1,1,1]]
]]
gpu,num_layer = 0,4
gazs,word_seq_tensor,biword_seq_tensor,word_seq_lengths,label_seq_tensor,layer_gaz_tensor,gaz_mask_tensor,mask = batchify_with_label(input_batch_list,gpu,num_layer)
print('gazs:',gazs)
print('word_seq_tensor:',word_seq_tensor)
print('biword_seq_tensor:',biword_seq_tensor)
print('word_seq_lengths:',word_seq_lengths)
print('label_seq_tensor:',label_seq_tensor)
print('layer_gaz_tensor:',layer_gaz_tensor)
print('gaz_mask_tensor:',gaz_mask_tensor)
print('mask:',mask)
输出:

gazs: [[[[2, 1], [3, 3]], [], [[[2], [3]]], [], [], []]]
word_seq_tensor: tensor([[1, 3, 2, 1, 2, 3]])
biword_seq_tensor: tensor([[2, 3, 4, 2, 2, 6]])
word_seq_lengths: tensor([6])
label_seq_tensor: tensor([[2, 2, 2, 1, 4, 2]])
layer_gaz_tensor: tensor([[[3, 0, 0, 0],
[0, 6, 0, 0],
[0, 0, 0, 0],
[9, 0, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 0]]])
gaz_mask_tensor: tensor([[[0, 1, 1, 1],
[1, 0, 1, 1],
[1, 1, 1, 1],
[0, 1, 1, 1],
[1, 1, 0, 1],
[1, 1, 1, 1]]], dtype=torch.uint8)
mask: tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.uint8)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-01 00:02:58  更:2022-04-01 00:06:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/9 0:55:59-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码