| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> GIT:斯坦福大学提出应对复杂变换的不变性提升方法 | ICLR 2022 -> 正文阅读 |
|
[人工智能]GIT:斯坦福大学提出应对复杂变换的不变性提升方法 | ICLR 2022 |
论文: Do Deep Networks Transfer Invariances Across Classes?
Introduction? 优秀的泛化能力需要模型具备忽略不相关细节的能力,比如分类器应该对图像的目标是猫还是狗进行响应,而不是背景或光照条件。换句话说,泛化能力需要包含对复杂但不影响预测结果的变换的不变性。在给定足够多的不同图片的情况下,比如训练数据集包含在大量不同背景下的猫和狗的图像,深度神经网络的确可以学习到不变性。但如果狗类的所有训练图片都是草地背景,那分类器很可能会误判房子背景中的狗为猫,这种情况往往就是不平衡数据集存在的问题。 Measuring Invariance Transfer In Class-Imbalanced Datasets? 论文先对不平衡场景中的不变性进行介绍,随后定义一个用于度量不变性的指标,最后再分析不变性与类别大小之间的关系。 Setup:Classification,Imbalance,and Invariances? 定义输入
(
x
,
y
)
(x,y)
(x,y),标签
y
y
y属于
{
1
,
?
?
,
C
}
\{1,\cdots,C\}
{1,?,C},
C
C
C为类别数。定义训练后的模型的权值
w
w
w,用于预测条件概率
P
~
w
(
y
=
j
∣
x
)
\tilde{P}_w(y=j|x)
P~w?(y=j∣x),分类器将选择概率最大的类别
j
j
j作为输出。给定训练集
{
(
x
(
i
)
,
y
(
i
)
)
}
i
=
1
N
~
P
t
r
a
i
n
\{(x^{(i)}, y^{(i)})\}^N_{i=1}\sim \mathbb{P}_{train}
{(x(i),y(i))}i=1N?~Ptrain?,通过经验风险最小化(ERM)来最小化训练样本的平均损失。但在不平衡场景下,由于
{
y
(
i
)
}
\{y^{(i)}\}
{y(i)}的分布不是均匀的,导致ERM在少数类别上表现不佳。 Measuring Learned Invariacnes? 为了度量分类器学习不变性的程度,论文定义了原输入和变换输入之间的期望KL散度(eKLD): ? 这是一个非负数,eKLD越低代表不变性程度就越高,对
T
T
T完全不变的分类器的eKLD为0。如果有办法采样
x
′
~
T
(
?
∣
x
)
x^{'}\sim T(\cdot|x)
x′~T(?∣x),就能计算训练后的分类器的eKLD。此外,为了研究不变性与类图片数量的关系,可以通过分别计算类特定的eKLD进行分析,即将公式2的
x
x
x限定为类别
j
j
j所属。 ? 训练方面,采用标准ERM和CE+DRS两种方法,其中CE+DRS基于交叉熵损失进行延迟的类平衡重采样。DRS在开始阶段跟ERM一样随机采样,随后再切换为类平衡采样进行训练。论文为每个训练集进行两种分类器的训练,随后计算每个分类器每个类别的eKLD指标。结果如图1所示,可以看到两个现象:
Trasnferring Invariances with Generative Models? 从前面的分析可以看到,长尾数据集的尾部类对复杂变换的不变性较差。下面将介绍如何通过生成式不变性变换(GIT)来显式学习数据集中的复杂变换分布 T ( ? ∣ x ) T(\cdot|x) T(?∣x),进而在类间转移不变性。 Learning Nuisance Transformations from Data? 如果有数据集实际相关的复杂变换的方法,可以直接将其用作数据增强来加强所有类的不变性,但在实践中很少出现这种情况。于是论文提出GIT,通过训练input conditioned的生成模型 T ~ ( ? ∣ x ) \tilde{T}(\cdot|x) T~(?∣x)来近似真实的复杂变换分布 T ( ? ∣ x ) T(\cdot|x) T(?∣x)。 ? 论文参考了多模态图像转换模型MUNIT来构造生成模型,该类模型能够从数据中学习到多种复杂变换,然后对输入进行变换生成不同的输出。论文对MUNIT进行了少量修改,使其能够学习单数据集图片之间的变换,而不是两个不同域数据集之间的变换。从图2的生成结果来看,生成模型能够很好地捕捉数据集中的复杂变换,即使是尾部类也有不错的效果。需要注意的是,MUNIT是非必须的,也可以尝试其它可能更好的方法。 ? 在训练中,论文设置阈值 K K K,仅图片数量少于 K K K的类进行数据增强。此外,仅对每个batch的 p p p比例进行增强。 p p p一般取0.5,而 K K K根据数据集可以设为20-500,整体逻辑如算法1所示。 GIT Improves Invariance on Smaller Classes? 论文基于算法1进行了实验,将Batch Sampler设为延迟重采样(DRS),Update Classifier使用交叉熵梯度更新,整体模型标记为 C E + D R S + G I T ( a l l c l a s s e s ) CE+DRS+GIT(all classes) CE+DRS+GIT(allclasses)。all classes表示禁用阈值 K K K,仅对K49数据集使用。作为对比,Oracle则是用于构造生成数据集的真实变换。从图3的对比结果可以看到,GIT能够有效地增强尾部类的不变性,但同时也损害了图片充裕的头部类的不变性,这表明了阈值 K K K的必要性。 Experiment? 不同训练策略搭配GIT的效果对比。 ? 在GTSRB和CIFAR数据集上的变换输出。 ? CIFAR-10上每个类的准确率。 ? 对比实验,包括阈值 K K K对性能的影响,GTSRB-LT, CIFAR-10 LT和CIFAR-100 LT分别取25、500和100。这里的最好性能貌似都比RandAugment差点,有可能是因为论文还没对实验进行调参,而是直接复用了RandAugment的实验参数。这里比较好奇的是,如果在训练生成模型的时候加上RandAugment,说不定性能会更好。 Conclusion? 论文对长尾数据集中的复杂变换不变性进行了研究,发现不变性在很大程度上取决于类别的图片数量,实际上分类器并不能将从大类中学习到的不变性转移到小类中。为此,论文提出了GIT生成模型,从数据集中学习到类无关的复杂变换,从而在训练时对小类进行有效增强,整体效果不错。 ?
|
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 | -2024/11/26 1:45:19- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |