IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> numpy 的transpose是如何实现的 -> 正文阅读

[人工智能]numpy 的transpose是如何实现的

背景

transpose在深度学习中是很常见的一个操作, numpy和pytorch都有对应的操作, 但是内部是如何实现的呢? stackoverflow上有很相信的说明, 这里搬运下.

transpose 是如何工作的?

  • 定义一个数组看看transpose的结果如何
In [28]: arr = np.arange(16).reshape((2, 2, 4))

In [29]: arr
Out[29]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],

       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]]])


In [32]: arr.transpose((1, 0, 2))
Out[32]: 
array([[[ 0,  1,  2,  3],
        [ 8,  9, 10, 11]],

       [[ 4,  5,  6,  7],
        [12, 13, 14, 15]]])

在np.array中, 对于三维数组的三个轴的定义如下:
在这里插入图片描述
数组内部实现上其实是使用一块连续内存保存数据的,在内存空间里, 这些数据的保存形式:

上图中的64 bytes, 32 bytes, 8bytes,为0, 1, 2三个轴的stride,换句话说, 在三个轴上取数时要用不同的stride跳跃, 轴0上每增加一位则要跳64bytes, 假如取(i, j, k)位的数arr[i, j, k], 那么可以知道:

# 这里的strides表示的位数,不是byte, 对于上图strides则为[8, 4, 1]
idx = strides[0] * i + strides[1] * j + strides[2] *k
arr[i, j, k] = buffer[idx]

当做arr.transpose(1, 0, 2)操作时, 需要将每个轴的dim和stride都换下,

strides变化:[64, 32, 8] ----------> [32, 64, 8]
shapes 变化:[2, 2, 4] ----------------> [2, 2, 4]
这样就避免了内存拷贝, transpose几乎无时间消耗.
在这里插入图片描述
在这里插入图片描述

代码实现

下面定义一个tensor来实现 transpose, 注意transpose时仅仅对strides, shapes做了顺序交换. 这里的strides表示的element移动个数,而不是bytes.

class Tensor:
    def __init__(self, data_lst, shapes):
        assert len(shapes) == 3
        self.data = data_lst
        self.shapes = np.array(list(shapes))
        self.strides = np.array([shapes[1] * shapes[2], shapes[2], 1])

    def print(self) -> str:
        print('(')
        for i in range(self.shapes[0]):
            print('[', end = '')
            for j in range(self.shapes[1]):
                print('[', end = '')
                for k in range(self.shapes[2]):
                    idx = i * self.strides[0] + j * self.strides[1] + k * self.strides[2]
                    print(self.data[idx], ' ', end = '')
                print(']', end = '') 
            print(']')
        print(')')
    
    def transpose(self, axes):
        assert len(axes) == len(self.shapes)
        axes = list(axes)
        self.shapes = self.shapes[axes]
        self.strides = self.strides[axes]
        return self

    def numpy(self):
        """ convert to numpy array

        Returns:
            _type_: np.ndarray
        """
        n_elements = self.shapes[0] * self.shapes[1] * self.shapes[2]
        arr = np.zeros((n_elements,))
        target_idx = 0
        for i in range(self.shapes[0]):
            for j in range(self.shapes[1]):
                for k in range(self.shapes[2]):
                    src_idx = i * self.strides[0] + j * self.strides[1] + k * self.strides[2]
                    arr[target_idx] = self.data[src_idx]
                    target_idx += 1
        return arr.reshape(self.shapes)

def test_tensor():
    axes = [1, 0, 2]
    arr = np.arange(16).reshape((2, 2, 4))
    t = Tensor(arr.reshape(-1).tolist(), (2, 2, 4))
    print('original arr:')
    t.print()
    print('numpy.transpose:', np.transpose(arr, axes))
    ret = t.transpose(axes).numpy()
    print('after transpose:', ret)
    assert np.allclose(np.transpose(arr, axes), ret)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-22 20:35:25  更:2022-02-22 20:35:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 11:21:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码