Sentence-transformer + FAISS 语义搜索
1. 关于FAISS
在上一篇博客中介绍了SBERT系列的各种应用,其中包括了语义搜索。但是在进行语义搜索的时候采用的是全局搜索的方法,时间代价较高,因而在数据库很大的情况下,需要采取一些近似计算的策略,其中FAISS就是一个非常有效的工具。
FAISS是由 Facebook AI Research开发一款基于C++的对稠密向量的语义检索和聚类工具,并提供了完整的python的封装。
由于其对稠密向量直接进行计算,这就使之可以具有非常多样的下游任务,理论上任何可以向量化的特征都可以利用这一工具实现快速检索,例如图片,文本等。
sentence-transformer模块提供的样例semantic_search_quora_faiss.py中,就包含了对结合FAISS使用的介绍,这篇博客基本内容就是对这个样例进行搬运,并做简单的注释。
2. 安装FAISS
打开pypi搜索FAISS,出现的第一条FAISS 1.5.1不要安装,那个项目停在两年前不维护了,也没有在py3.8以上的版本上编译,所以直接下载cpu版本或者gpu版本的1.7.1.
3. 下载预训练模型和数据集
在样例中,数据集采用的是quora的重复问句数据集,点此下载。这个数据集原本是用来做句对匹配训练的,每一条中有两个question,以及这两个question是否是重复提问的label,但是在这个应用场景中,语义检索希望把用户输入问句的最相关的问句返回,并不关心重复的问题,于是将question1和question2一视同仁的看待。这个数据集里一共有500k的语句。
关于预训练模型,样例中给出的是在quora上训练的’quora-distilbert-multilingual’,而我使用的是mpnet-base,实验效果也是不错的,预训练模型的下载和使用可以参考上一篇博客。
from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time
import faiss
import numpy as np
4. 获取embedding
由于数据量比较大,不可能每次检索的时候都对库里的所有语句现场编码,所以需要把所有候选语句全都编码完,保存到本地。我在实验时使用3090对全部500k的候选用mpnet-base编码完,总共耗时约5分钟。 首先设置一下embedding相关的参数:
dataset_path = "your_path_to_/quora_duplicate_questions.tsv"
max_corpus_size = 100000
embedding_cache_path = 'quora-embeddings-{}-size-{}.pkl'.format(model_name.replace('/', '_'), max_corpus_size)
embedding_size = 768
然后开始编码,并保存到本地:
if not os.path.exists(embedding_cache_path):
corpus_sentences = set()
with open(dataset_path, encoding='utf8') as fIn:
reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
for row in reader:
corpus_sentences.add(row['question1'])
if len(corpus_sentences) >= max_corpus_size:
break
corpus_sentences.add(row['question2'])
if len(corpus_sentences) >= max_corpus_size:
break
corpus_sentences = list(corpus_sentences)
print("Encode the corpus. This might take a while")
corpus_embeddings = model.encode(corpus_sentences, show_progress_bar=True, convert_to_numpy=True)
print("Store file on disc")
with open(embedding_cache_path, "wb") as fOut:
pickle.dump({'sentences': corpus_sentences, 'embeddings': corpus_embeddings}, fOut)
else:
print("Load pre-computed embeddings from disc")
with open(embedding_cache_path, "rb") as fIn:
cache_data = pickle.load(fIn)
corpus_sentences = cache_data['sentences']
corpus_embeddings = cache_data['embeddings']
然后这个pkl会保存到本地,100k的编码文件大概有300M大小。
5. 创建索引并训练
首先设置一下相关的参数。
top_k_hits = 10
n_clusters = 1024
_nprobe = 3
然后创建索引对象,并传入训练数据。
quantizer = faiss.IndexFlatIP(embedding_size)
index = faiss.IndexIVFFlat(quantizer, embedding_size, n_clusters, faiss.METRIC_INNER_PRODUCT)
index.nprobe = _nprobe
corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1)[:, None]
index.train(corpus_embeddings)
index.add(corpus_embeddings)
6. 应用
训练完成之后就可以应用了。
while True:
inp_question = input("Please enter a question: ")
start_time = time.time()
question_embedding = model.encode(inp_question)
question_embedding = question_embedding / np.linalg.norm(question_embedding)
question_embedding = np.expand_dims(question_embedding, axis=0)
distances, corpus_ids = index.search(question_embedding, top_k_hits)
hits = [{'corpus_id': id, 'score': score} for id, score in zip(corpus_ids[0], distances[0])]
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
end_time = time.time()
print("Input question:", inp_question)
print("Results (after {:.3f} seconds):".format(end_time-start_time))
for hit in hits[0:top_k_hits]:
print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))
correct_hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k_hits)[0]
correct_hits_ids = set([hit['corpus_id'] for hit in correct_hits])
ann_corpus_ids = set([hit['corpus_id'] for hit in hits])
if len(ann_corpus_ids) != len(correct_hits_ids):
print("Approximate Nearest Neighbor returned a different number of results than expected")
recall = len(ann_corpus_ids.intersection(correct_hits_ids)) / len(correct_hits_ids)
print("\nApproximate Nearest Neighbor Recall@{}: {:.2f}".format(top_k_hits, recall * 100))
if recall < 1:
print("Missing results:")
for hit in correct_hits[0:top_k_hits]:
if hit['corpus_id'] not in ann_corpus_ids:
print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))
print("\n\n========\n")
以上就是本期全部内容了,如果这篇博客对你有帮助的话,还请一键三连支持一下up主,我们下期再见。
|