IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> torch.gather()用法详解 -> 正文阅读

[人工智能]torch.gather()用法详解

官方示例

链接:https://pytorch.org/docs/stable/generated/torch.gather.html

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
# 沿由 dim 指定的轴收集input的值,其输出形状与index相同。
  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

对于一个3维的tensor,其输出如下:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

图解

参考:https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms

torch.gather 通过沿输入维度 dim 从每一行获取值,从输入张量创建一个新张量。 torch.LongTensor 中的值作为索引传递,指定从每个“行”中获取的值。 输出张量的维度与索引张量的维度相同。 以下图片能更清楚地解释它:
在这里插入图片描述

图1:torch.gather二维用法

label smoothing 中的用法

参考:https://www.pythonfixing.com/2021/11/fixed-label-smoothing-in-pytorch.html

假设GT是 ( 0 , 1 , 0 ) (0, 1, 0) (0,1,0),设label smoothing的系数 α = 0.2 \alpha=0.2 α=0.2,我们想象中的结果应该是 ( 0.1 , 0.8 , 0.1 ) (0.1,0.8,0.1) (0.1,0.8,0.1),以上标签是上述参考的结果。然而实际上在pytorchtensorflow中都不是这么实现的,它们俩得到的标签应该是 ( 0.8 + 0.2 / 3 , 0.2 / 3 , 0.2 / 3 ) (0.8+0.2/3,0.2/3,0.2/3) (0.8+0.2/3,0.2/3,0.2/3),再进行交叉熵计算得到结果。

import torch
import torch.nn as nn

# label smoothing
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self):
        super(LabelSmoothingCrossEntropy, self).__init__()
    def forward(self, x, target, smoothing=0.1):
        confidence = 1. - smoothing
        logprobs = x.log_softmax(dim=-1)
        print(logprobs.shape)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1) # cross_entropy loss without mean
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss
        return loss.mean()

ce0 = LabelSmoothingCrossEntropy()
ce1 = nn.CrossEntropyLoss(label_smoothing=0.1)

predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
                                 [0, 0.9, 0.2, 0.2, 1], 
                                 [1, 0.2, 0.7, 0.9, 1]])
label = Variable(torch.LongTensor([2, 1, 0]))
out0 = ce0(predict,target)  # tensor(1.3096)
out1 = ce1(predict,target)  # tensor(1.3096)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-18 17:43:23  更:2022-04-18 17:44:26 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 3:25:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码