IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch按照索引取batch中的数 -> 正文阅读

[人工智能]pytorch按照索引取batch中的数

比如说bert的输出表征是基于子词的,想要用于基于词的任务,需要将词对应的最后一个子词的表征取出来,代码如下:

import torch
import torch.nn as nn

torch.manual_seed(1)

mix = torch.randn([2, 5, 3])
range_vector = torch.tensor([[0], [1]])
offsets2d = torch.tensor([[1, 3, 0], [1, 2, 4]])
print(mix)
selected_embeddings = mix[range_vector, offsets2d]
print(selected_embeddings.size())
print(selected_embeddings)

结果:


tensor([[[-1.5256, -0.7502, -0.6540],
         [-1.6095, -0.1002, -0.6092],
         [-0.9798, -1.6091, -0.7121],
         [ 0.3037, -0.7773, -0.2515],
         [-0.2223,  1.6871, -0.3206]],

        [[-0.2993,  1.8793, -0.0721],
         [ 0.1578, -0.7735,  0.1991],
         [ 0.0457, -1.3924,  2.6891],
         [-0.1110,  0.2927, -0.1578],
         [-0.0288,  2.3571, -1.0373]]])
torch.Size([2, 3, 3])
tensor([[[-1.6095, -0.1002, -0.6092],
         [ 0.3037, -0.7773, -0.2515],
         [-1.5256, -0.7502, -0.6540]],

        [[ 0.1578, -0.7735,  0.1991],
         [ 0.0457, -1.3924,  2.6891],
         [-0.0288,  2.3571, -1.0373]]])

Process finished with exit code 0

此段代码,在模型转成onnx时候会报错,可改成

mix = torch.randn([2, 5, 3]).cuda()
offsets = torch.tensor([[1, 3, 0], [1, 2, 4]]).cuda()

# 按索引取数
B, S, D = mix.size()
new_mix = mix.view(-1, D)
_, W = offsets.size()
right_add = torch.arange(0, B).unsqueeze(-1).cuda()
right_add = right_add * S
right_add.expand([B, W])

new_offsets = right_add + offsets

new_offsets = new_offsets.view(-1)
print(new_offsets)
out1 = new_mix.index_select(0, new_offsets)
# index_select 必须是一维向量
# torch.gather输出维度和输入的维度必须相同
print(out1.view(B, W, -1))

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-23 12:29:25  更:2021-10-23 12:33:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 8:46:01-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码