这一次朱毅博士给大家精读的论文是 CLIP ,来自于 OpenAI ,是图像文本多模态领域一个里程碑式的工作。
项目链接为:https://openai.com/blog/clip/
0. 前言
CLIP 自从去年2月底提出就立马火爆全场,它的方法出奇的简单、但是效果又出奇的好,很多结果和结论都让人瞠目结舌。比如作者说 CLIP 的迁移学习能力是非常强的,预训练好的这个模型能够在任意一个视觉分类的数据集上取得不错的效果;而且最重要的是它是 zero-shot 的,意思就是说它完全没有在这些数据集上去做训练 ,就能得到这么好的效果。论文里做了非常多的实验,在超过30个数据集上去做了测试,涵盖的面非常广,包括了 OCR 、视频动作检测、还有坐标定位和许多细分类任务。在所有这些结果之中,其中最炸裂的一条就是在 ImageNet 上的结果了,CLIP 在不使用 ImageNet 的训练集的情况下,也就是不使用128万张图片中任何一张的情况下直接 zero-shot 做推理,就能获得和之前有监督训练好的 Res50 同样的效果,在 CLIP 这篇工作出来之前,很多人都认为这是不可能的事情。
下面就先来看文章的图一,也就是 CLIP 的模型总览图。这里先大概说一下它的流程,具体的细节后面还会讲到。
首先 CLIP 是如何进行预训练的呢? 通过论文题目就可以大概猜到一二,题目的意思是说:通过自然语言处理 的监督信号可以去训练一个迁移效果很好的视觉模型 ,所以很自然的这是一个涉及文字 、图片的一个多模态的工作 。那它是怎么去利用自然语言处理来的监督信号呢?其实在训练的过程中,模型的输入是一个图片和文字的配对 ,比如这里图片里画的是一只狗,配对的文字是 pepper ,是一只小狗;然后图片通过一个图片的编码器从而得到一些特征,这里的编码器既可以是 ResNet ,也可以是 Vision Transformer ,对于这个句子来说,它也会通过一个文本的编码器从而得到文本的特征。假设说现在每个 training batch 都有
n
n
n 个图片文本对,也就是这里有
n
n
n 张图片,有
n
n
n 个句子,那我们就会得到
n
n
n 个图片的特征,还有
n
n
n 个文本的特征。然后 CLIP 就是在这些特征上去做对比学习 ,之前也提过对比学习非常的灵活,只需要一个正样本和负样本的定义,其他都是正常套路。那这里什么是正样本什么是负样本?其实再明确不过了,这里配对的一个图片文本对就是一个正样本 ,因此特征矩阵里沿着对角线方向上都是正样本 。因为
I
1
T
1
,
I
2
T
2
I_1T_1 ,I_2 T_2
I1?T1?,I2?T2? 这些本身就都是配对的,剩下矩阵里所有不是对角线上的元素就都是负样本了。就是说这里我们是有
n
n
n 个正样本,然后有
n
2
?
n
n^2 - n
n2?n 个负样本。
一旦有了正负样本,模型就可以通过对比学习的方式去训练起来了,完全不需要任何手工的标注。当然,对于这种无监督的预训练方式,是需要大量的数据的。OpenAI 还专门去收集了一个数据集,里面有4亿个图片和文本对,而且这个数据应该是清理的非常好的,质量应该非常高。这也是 CLIP 预训练模型为什么能这么强大的主要原因之一。
那接下来就是 CLIP 如何去做 zero-shot 的推理了? 这部分是非常有意思的。因为 CLIP 的这个模型经过预训练之后,其实只能去得到一些视觉上和文本上的特征,并没有在任何分类的任务上去做继续的训练或者微调,是没有这么一个分类头的。那如果没有分类头怎么去做推理呢? 作者这里就想出来一个巧妙的利用自然语言的一种方法 prompt template ,这里拿 ImageNet 做个例子。CLIP 先把 ImageNet 里这1000个类,比如说图一中的飞机、汽车、狗变成一个句子,用这些物体去替代这里的 object ,就是把一个单词变成了这么一个句子。ImageNet 有1000个类,这里就会生成1000个句子,然后这1000个句子通过我们之前预训练好的文本编码器就会得到1000个文本的特征。那我们为什么要做 prompt template 呢?其实直接用这些单词去抽取这种文本的特征也是可以的,但是因为在模型预训练的时候,图片每次看到的基本都是一个句子,如果在推理的时候突然把所有的文本都变成一个单词,就跟在训练的时候看到的这个文本不太一样了,效果就会稍有下降。而且怎么变成句子也是很有讲究的, CLIP 这篇论文后面还提出了 prompt engineering 和 prompt ensemble 这两种方式去进一步的提高模型的准确率,而不需要重新训练模型。在推理的时候,不论输入任何一张照片,只要把这张照片扔给图片的编码器,得到了图片特征之后,就去拿这个图片的特征去跟所有的文本特征去做 cosine similarity 计算。图像的特征跟哪个文本特征最相似,就把这个文本特征所对应的那个句子挑出来,从而完成了分类任务。
是不是非常的巧妙,其实当 CLIP 真正使用的时候,这里的标签也是可以改的,不光是 ImageNet 这1000个类,可以换成任何的单词。这里的图片也不需要是 ImageNet 的图片,也可以是任何的图片。依旧可以通过这种算相似度的方式去判断出这张图片里到底有哪些物体。比如说这里给一张三轮车的照片,然后在上面的类别里也加上三轮车类别,通过 CLIP 的这种 zero-shot 的推理的方式很有可能这张图片就能正确的被分类成三轮车。但如果像之前那种严格按照1000类去训练的分类头来说的话,模型永远都不会判断出这张图片是三轮车,最多也就是把它判断成是车或者是自行车。
这个性质才是 CLIP 的强大之处,也是 CLIP 这个模型最吸引人的地方。因为它彻底摆脱了 categorical label 的限制,也就是说不论是在训练的时候还是在推理的时候,都不需要有这么一个提前定以好的标签列表了,任意给一张照片,都可以通过给模型去喂不同的文本句子从而知道这张图片里到底有没有感兴趣的物体。
而且 CLIP 不光是能识别新的物体,由于它真的把视觉的语义和文字的语义联系到了一起,所以它学到的特征语义性非常强,迁移的效果也非常的好。在 OpenAI 的官网上,也就是 CLIP 的官方博客上作者还举了这么个例子。在 ImageNet 数据集上之前训练的 ResNet101 是76.2的准确率,然后对于 CLIP 训练出来的 VIT-Large 也是76.2的准确率,这两个的准确率是一样的。但是当我们换个数据集,比如说换到 ImageNetV2 或者 ImageNet rendition、ObjectNet、ImageNet sketch 或者 ImageNet Adversarial 之后,就会发现之前这种严格按照1000类分类头训练出来的模型,准确率下降的非常快。在换到素描画的时候或者对抗性样本的时候,准确率直接从70多掉到20多,甚至到对抗性的这种 2.7,基本上就已经是在随机猜了。迁移的效果是惨不忍睹,但是对于 CLIP 训练出来的模型来说,效果始终都非常的高,没有什么下降。这也就从侧面说明了因为和自然语言处理的结合,导致 CLIP 学出来的视觉特征和我们用语言所描述的某个物体已经产生了强烈的联系。比如说这里的香蕉不论是在自然图像里出现的香蕉、还是动漫里的香蕉、还是说素描的香蕉或者是加过对抗性样本的香蕉,CLIP 训练出来的模型都知道它对应的是香蕉这个单词。所以他能在 domain 变化剧烈的情况下,依旧能识别出这张图片里画的是香蕉,准确的完成这个分类任务。
这里再介绍一个基于 CLIP 的有趣的应用(ViLD),目标检测。在 CLIP 出来之后,很快也就是一个半月左右,Google 就出了一篇利用 CLIP 去做物体检测的工作。可以来看一下效果怎么样,作者这里说如果你是用传统的目标检测的方法去做预测的话,模型可能只能告诉你这些都是玩具,也就是这里的蓝色的这种基础类,但是当你利用了自然语言之后,就摆脱了基础类的限制,就可以随意发挥了。也就是这里说的 open vocabulary detector ,训练出来的模型就可以检测出这些新的类,也就是红色的这些类,比如说不光可以知道这些玩具的颜色,同时还知道这些玩具具体所代表的物体类别,比如这是一个玩具大象、一个玩具鳄鱼、一个玩具鸭子,输出一下就丰富了不少。
1. 题目和作者
首先我们来看一下 CLIP 这篇论文的整体架构。CLIP 这篇论文有48页,就算是把补充材料的10页去掉,CLIP 的正文也有30多页。其中大部分的篇幅都是留给了实验和相应的一些分析。从头开始这里面这一块是摘要,然后接下来一页多的内容主要是在讲引言,然后接下来的两页就是讲了一下方法,主要说的是怎么做预训练;然后接下来从第6页一直到第18页全都是说的实验,当然这里面也包括了怎么去做这种 zero-shot 的推理、还有包括这种 prompt engineering,prompt ensemble ,算是方法和实验的一个合体;然后讲完了实验,作者大概花了一页的篇幅去讨论了一下 CLIP 这个工作的一些局限性;然后接下来的五页作者主要就是讨论了一下 CLIP 这篇工作有可能能带来的巨大的影响力,在这个部分作者首先讨论的是 bias ,就是一些模型的偏见,然后讨论了 CLIP 有可能在监控视频里的一些应用;然后最后作者展望了一下 CLIP 还有哪些可以做的这个未来工作;然后作者用了一页的篇幅说了一下相关工作,最后因为该讨论的都讨论了,该说的都说完了,结论其实非常短,就这么一小段,写的 还没有后面的致谢多。然后因为 CLIP 做的这个数据集比较多,比的方法也比较多,牵扯到的工作也比较多,因为是视觉、 nlp 、多模态、包括有监督学习、无监督学习、自监督学习,很多领域里的工作都要提到,所以说可以看到引用文献在这种双栏的情况下就写了整整八页还多,跟一篇顶会投稿是一个长度。
先看论文题目 ,题目意思是:利用自然语言的监督信号去学习一个可迁移的视觉网络。
这里有两个关键词,一个是可迁移 ,另一个就是利用自然语言的监督信号 ,所以怎么利用自然语言的监督信号就是这篇论文的贡献所在;至于想达到的目的主要就是迁移性 ,就是想去学一个泛化性非常好的特征,从而能在各种数据集或者各种任务上能够不需要训练,直接推理都能获得不错的效果。
作者团队全部来自 OpenAI ,有12个作者,但其实对于 CLIP 这种工作,做了这么多实验、刷了这么多数据集,光论文就写了48页,12个人其实并不多,论文后面还致谢了很多很多人。
2. 摘要
下面是论文摘要 ,摘要总共有9句话。
- 第1句话说现在最先进的计算机视觉系统是怎么训练的? 都是先有一个固定的已经提前定义好的物体类别集合,然后模型通过去预测这些提前定义好的类别从而完成模型的训练。这个固定的提前定义好的标签集合怎么理解呢?就好比
ImageNet 它有固定的1000个类、CIFAR10 就有10个类、CIFAR100 就是100个类、目标检测 COCO 就是80个类、语义分割 Cityscapes 就有19个类、视频分类 Kinetics 数据集就有400个类。总之为了简单起见,不光是收集数据集的简单性还是说从模型训练的简单性来说,直接定义这么一个固定的提前定义好的标签集合会大大的简化问题本身。 - 第2句话说因为采用了这种有限制性的监督信号,
从而也限制了模型本身的泛化性 ,尤其是当你需要去识别新物体类别的时候。对于这些新的类别难道每次都要去收集新的数据,然后重头训练一个新的模型吗?这样就很有局限性,就不好 scale 了。 - 第3句话说作者想到了另外一种方式,
直接从关于图像的文本里去学习监督信号是一个看起来非常有前途的办法 ,因为它的监督信号涵盖的范围就太广了。只要是你语言描述过的物体就有可能让视觉模型去识别到这个物体,而不仅仅是提前定义好的那1000个类。 - 然后作者接下来说他们已经证实了,用一个非常简单的一个预训练的任务,就可以非常高效的且可扩展的去学习一些最好的这个图像的表征。那这个任务具体是什么呢?其实是给定一些图片,然后又给定一些句子,
模型需要去判断哪一个句子跟图片是配对的 。既然要做这么一个配对任务,当然就需要一个类似的数据集了,也就是说训练样本必须是一个图片和文字的配对。文章作者又去爬了一个超级大的有4个亿的图片文本配对的数据集。有了这么大的一个数据集之后,就可以选择一种自监督的训练方式去预训练一个大模型出来了。 - 在预训练完成之后,自然语言就被用来去引导视觉模型去做物体的分类,也就是之前说的
prompt ,然后分类也不光局限于已经学到的视觉概念,还能扩展到新的类别。现在学到的模型是能够直接在下游任务上去做 zero-shot 的推理的。 - 第6-9句作者介绍了模型效果。为了验证模型的有效性,作者接下来在超过30个不同的视觉任务和数据集上做了测试。然后作者发现模型的迁移学习效果对大多数任务来说效果都是非常好的。
CLIP 模型在不需要任何数据集专门训练的情况下,能和之前一个完全用有监督方式训练出来的模型打平手、甚至还会更高。作者在这里取了一个 ImageNet 的效果,CLIP 模型在不使用任何一张128万个训练集图片的情况下,能跟之前一个有监督训练好的 ResNet-50 打成平手,这个结果在当时读起来是非常震惊的。作者说他们预训练好的模型和他们的代码都放到了这个链接里,但其实这里的代码只是可以用来做推理,OpenAI 并没有开源他真正的预训练的代码。
3. 引言
下面是论文引言 部分,先看前2段。
文章一开始说直接从原始的文本数据里去预训练一个模型已经在过去几年里,在 NLP 领域取得了革命性的成功,比如说 BERT, GPT 这些模型。不论是使用自回归预测的方式还是使用掩码完形填空的方式,都是一种自监督的训练方式,目标函数是跟下游任务无关的,只是想通过预训练得到一个非常好的能泛化的特征,随着计算资源的增多、模型的变大还有数据变得更多,模型的能力也会稳健的提升。那这一套系统说白了其实就是文字进文字出,并不是在做一个什么特殊的分类任务,模型架构也是跟下游任务无关的。当直接用在下游任务上的时候,就不需要费尽心思去研究一个针对下游任务的输出头或者说针对那个数据集的特殊的一些处理了。这里面最厉害的最耳熟能详的模型就是 OpenAI 自己的 GPT-3 ,能够做分类、做翻译、还能写邮件、写小说、写新闻。而且在大多数任务上并不需要特定领域的数据或者说只需要一点点的数据去做一点微调,就可以和之前精心设计过的那些网络取得差不多的结果。
作者接下来说这些鼓舞人心的结果证实了,在这种文本进文本出,利用自监督的信号去训练整个模型的框架下。这种大规模的没有标注的数据其实是要比那些手工标注的质量非常高的数据集反而是要更好使的。但是在视觉领域大家一般的做法还是在 ImageNet 这种数据集上去训练一个模型,这会让模型有诸多的限制,NLP 里的这套框架到底能不能用在视觉里呢?作者说从之前的工作看起来应该是没问题。
然后接下来作者就讨论了一下之前的工作,作者列了三个工作 VirTex, ICMLM , ConVIRT 。这三个工作都是基于 Transformer 去做的,跟 CLIP 已经非常非常像了,但具体的做法还是有一些区别。比如 VirTex 用的是自回归的预测方式去做模型的预训练,ICMLM 是用完形填空的方式去做预训练,ConVIRT 就跟 CLIP 已经非常类似了,但是只在医疗图像上做了实验。总之这三种方法都没有在模型或者数据上取得一个很大的规模,所以说就没有像 CLIP 模型一样一战成名。
然后作者反思了一下说,既然利用自然语言的监督信号是一个很有前途的方向,那为什么在视觉里这一系列的工作就这么少呢?这是因为之前的那些方法没有这么大的数据集、没有这么多的算力、没有这么大的模型、没有这么好的自监督训练的方式。在标准的测试数据集上,比如说 ImageNet 上的效果就非常的差,比如我们刚才说的这个工作在 zero-shot 的设定下就只有11.5的这个准确率,但是在 ImageNet 上之前最好的表现方法都已经有88.4了,而且就算是不用 Deep Learning ,就用传统的视觉方法准确度也有50。因为效果实在太低,所以说没有实用的价值,大家去钻研这个方向的热情就会小很多。取而代之呢,另一系列的工作就非常受人关注,就是说怎么去利用更弱的监督信号。之前一个工作就提出了一个新的预训练的任务,收集了一个 instagram 的数据集,去预测图片自身带的 Hash Tag ,通过这种方法呢,数据集就可以变得非常大,因为每张图片都会自带一些 Hash Tag ,不用去人工的标注。同时这些 Hash Tag 也就可以想象成是一种自然语言的监督信号,是一个或者几个单词,有明确的语义含义。然后另外还有一些工作,比如说这两篇工作就是去训练一个模型,然后去预测 JFT-300 Million 这个数据集上所有的这个类别。因为 JFT-300 Million 数据集非常大,类别数好像有18000,因为有这么多的类别,他的标注其实是比较 noisy 的,也算是一种弱的监督信号。总之作者想说的是,这一系列的工作之所以更火爆,就是因为他们的性能会更好。
4. 模型
5. 实验
6. 模型局限性
7. 结论和总结
8. CLIP预训练demo
这里使用 OpenAI 提供的 notebook 演示 CLIP 的效果,地址为: https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb。
安装 CLIP :
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
导入需要的库,PyTorch 版本在 1.7.1及以上:
import numpy as np
import torch
from pkg_resources import packaging
print("Torch version:", torch.__version__)
下面是加载模型:
import clip
clip.available_models()
'''
['RN50',
'RN101',
'RN50x4',
'RN50x16',
'RN50x64',
'ViT-B/32',
'ViT-B/16',
'ViT-L/14',
'ViT-L/14@336px']
'''
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
'''
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408
'''
下面是图片预处理,即 preprocess :
Compose(
Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
然后是文本预处理,得到77个 tokens :
clip.tokenize("Hello World!")
'''
tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
'''
然后是输入图片和文本对,这里输入的图片文本对为8对:
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}\n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
plt.tight_layout()
下面是得到图像和文本特征:
image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
计算余弦相似性,可以看到对角线上的相似性值最高:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size=20)
下面进行 zero-shot 分类,这里使用 CIFAR100 数据集,可能是类别比较少的原因,第一张图这里分错了:
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
|