IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch的python API略读--tensor(三) -> 正文阅读

[人工智能]pytorch的python API略读--tensor(三)

作者:机器视觉全栈er
来源:cvtutorials.com

2.1.2 索引

筛选出符合某种条件的subtensor。

torch.where: 根据布尔变量的值选择tensor中的元素,用法如下:

torch.where(condition, x, y)

下面举个简单的例子:

>>> import torch
>>> cvtutorials = torch.randn(3, 4)
>>> threshold = torch.zeros(3, 4)
>>> cvtutorials
tensor([[-1.6981,  1.0443,  2.7922, -0.8736],
        [-2.0208, -0.4815, -0.1488, -0.9714],
        [ 1.1035,  0.4089,  0.6279,  2.4600]])
>>> torch.where(cvtutorials > 0, cvtutorials, threshold)
tensor([[0.0000, 1.0443, 2.7922, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [1.1035, 0.4089, 0.6279, 2.4600]])

上面torch.where函数返回tensor的某个元素的值遵循这样的选择:如果cvtutorials中的某个元素大于0,那么保留,否则设置为0,用数学公式表达如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-itiDkBU3-1645712444677)(https://gitee.com/cvtutorials/cvtutorials_picbed/raw/master/pth_hbh/image-20220224145457438.png)]

torch.index_select: 沿着某个维度,通过index对输入tensor进行筛选。用法如下:

torch.index_select(input, dim, index, *, out=None)

举个例子说明下:

>>> cvtutorials = torch.randn(2,3)
>>> cvtutorials
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> indices = torch.tensor([0, 1])
>>> torch.index_select(cvtutorials, 0, indices)
tensor([[-0.9935, -0.9802, -0.6104],
        [ 2.6251, -1.0099,  0.4752]])
>>> torch.index_select(cvtutorials, 1, indices)
tensor([[-0.9935, -0.9802],
        [ 2.6251, -1.0099]])

torch.masked_select: 根据设置的mask,返回一个一维的tensor(向量)。用法如下:

torch.masked_select(input, mask, *, out=None)

举个简单的例子:

>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 1.1016, -1.5259,  1.1065],
        [ 0.4838, -0.5521,  0.1556]])
>>> mask = torch.tensor([[False, True, True], [True, False, False]])
>>> mask
tensor([[False,  True,  True],
        [ True, False, False]])
>>> torch.masked_select(cvtutorials, mask)
tensor([-1.5259,  1.1065,  0.4838])

从中可以看出,根据mask对输入tensor相应位置的元素进行筛选,mask某位置为True,则取出tensor相应位置的元素,否则,不取出。

还有一点,mask的shape不一定和tensor一样,但是需要broadcast到tensor上,例如:

>>> cvtutorials = torch.randn(2, 3)
>>> cvtutorials
tensor([[ 0.8686,  0.0910,  1.8702],
        [ 1.8140, -1.0902,  0.7051]])
>>> mask = torch.tensor([[False, True, True]])
>>> mask
tensor([[False,  True,  True]])
>>> torch.masked_select(cvtutorials, mask)
tensor([ 0.0910,  1.8702, -1.0902,  0.7051])
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-26 11:31:25  更:2022-02-26 11:33:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 2:47:15-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码