IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 探究torchAudio中wav2vec2的源码(二)——特征提取 -> 正文阅读

[人工智能]探究torchAudio中wav2vec2的源码(二)——特征提取

前文再续,书接上一回,我们看看wav2vec2怎么提取特征。

在论文中,wav2vec2是通过conv1d进行特征提取的。如下图:

请添加图片描述

而conv1d的具体结构也已经给出:

请添加图片描述

extractor_conv_layer_config列表中的三元组的含义分别表示:out_channels、kernel_size、stride。

然后我们看看feature_extractor是怎么生成的。

特征提取

我们跟着提示点进components.py文件中找到_get_frature_extractor方法。
请添加图片描述

主要工作:

这个函数做的主要就是构造一个conv1d模型,用于对原音频的特征提取。

三个形参为:norm_mode表示归一化的模式选择、shapes表示conv1d的结构参数、bias表示是否设置偏置值

里面的归一化模型参数有:

  • GroupNorm:将channel方向分group,然后每个group内做归一化,算(C//G)*H*W的均值
  • LayerNorm:channel方向做归一化,算C*H*W的均值

总体结构如下代码:

def _get_feature_extractor(
    norm_mode: str,
    shapes: List[Tuple[int, int, int]],
    bias: bool,
) -> FeatureExtractor:
    assert norm_mode in ["group_norm", "layer_norm"]
    # 块
    blocks = []
    # 输入:语音
    in_channels = 1
    # 穷举结构参数
    for i, (out_channels, kernel_size, stride) in enumerate(shapes):
        # 归一化模型
        normalization = None
        # 如果归一化模型是组归一化
        if norm_mode == "group_norm" and i == 0:
            # 组归一化模型构造
            normalization = nn.GroupNorm(
                num_groups=out_channels,
                num_channels=out_channels,
                affine=True,
            )
        # 如果是层归一化
        elif norm_mode == "layer_norm":
            # 层归一化模型构造
            normalization = LayerNorm(
                normalized_shape=out_channels,
                elementwise_affine=True,
            )
        # conv1d模型构造,把归一化也传进去
        blocks.append(
            ConvLayerBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=bias,
                layer_norm=normalization,
            )
        )
        # 输出channel变输入channel(1变512、512变512.....)
        in_channels = out_channels
        # 把列表blocks转换成pytorch模型列表然后返回提取特征的结果
    return FeatureExtractor(nn.ModuleList(blocks))

ConvLayerBlock对象

点开ConvLayerBlock对象查看相关代码:

请添加图片描述

这是ConvLayerBlock对象的初始化代码。就是设计conv1d的结构。然后还有forword函数如下(forword函数就是把数据传入模型中跑的函数):

请添加图片描述

FeatureExtractor对象

接下来看看_get_feature_extractor函数要返回的FeatureExtractor对象。

请添加图片描述

初始化就是把conv1d的nn.ModuleList模型传进去。主要就是forword函数了。

这里的116行增加一个维度是为了符合in_channel=1的设定。

这里的119行的layer就是调用了ConvLayerBlock对象中的forword函数。

总结

请添加图片描述

特征提取需要用到_get_frature_extractor方法,其中_get_frature_extractor方法主要调用了ConvLayerBlock对象和FeatureExtractor对象进行特征提取模型的构建和使用。

下一个博客我们看transfromer_encoder的构建。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-07 22:41:41  更:2022-04-07 22:45:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/16 0:44:27-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码