IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> NLP实践——Sentence-transformer + FAISS 语义搜索 -> 正文阅读

[人工智能]NLP实践——Sentence-transformer + FAISS 语义搜索

1. 关于FAISS

上一篇博客中介绍了SBERT系列的各种应用,其中包括了语义搜索。但是在进行语义搜索的时候采用的是全局搜索的方法,时间代价较高,因而在数据库很大的情况下,需要采取一些近似计算的策略,其中FAISS就是一个非常有效的工具。

FAISS是由 Facebook AI Research开发一款基于C++的对稠密向量的语义检索和聚类工具,并提供了完整的python的封装。

由于其对稠密向量直接进行计算,这就使之可以具有非常多样的下游任务,理论上任何可以向量化的特征都可以利用这一工具实现快速检索,例如图片,文本等。

sentence-transformer模块提供的样例semantic_search_quora_faiss.py中,就包含了对结合FAISS使用的介绍,这篇博客基本内容就是对这个样例进行搬运,并做简单的注释。

2. 安装FAISS

打开pypi搜索FAISS,出现的第一条FAISS 1.5.1不要安装,那个项目停在两年前不维护了,也没有在py3.8以上的版本上编译,所以直接下载cpu版本或者gpu版本的1.7.1.

3. 下载预训练模型和数据集

在样例中,数据集采用的是quora的重复问句数据集,点此下载。这个数据集原本是用来做句对匹配训练的,每一条中有两个question,以及这两个question是否是重复提问的label,但是在这个应用场景中,语义检索希望把用户输入问句的最相关的问句返回,并不关心重复的问题,于是将question1和question2一视同仁的看待。这个数据集里一共有500k的语句。

关于预训练模型,样例中给出的是在quora上训练的’quora-distilbert-multilingual’,而我使用的是mpnet-base,实验效果也是不错的,预训练模型的下载和使用可以参考上一篇博客

from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time
import faiss
import numpy as np

4. 获取embedding

由于数据量比较大,不可能每次检索的时候都对库里的所有语句现场编码,所以需要把所有候选语句全都编码完,保存到本地。我在实验时使用3090对全部500k的候选用mpnet-base编码完,总共耗时约5分钟。
首先设置一下embedding相关的参数:

# 下载的数据集的路径
dataset_path = "your_path_to_/quora_duplicate_questions.tsv"
# 在500k的数据集里选择多少条构建候选语料库
max_corpus_size = 100000
# 候选语料保存在本地文件的名称
embedding_cache_path = 'quora-embeddings-{}-size-{}.pkl'.format(model_name.replace('/', '_'), max_corpus_size)
# 预训练模型的编码输出特征维度
embedding_size = 768

然后开始编码,并保存到本地:

#Check if embedding cache path exists
if not os.path.exists(embedding_cache_path):
    # Check if the dataset exists. If not, download and extract
    # Download dataset if needed
    # if not os.path.exists(dataset_path):
    #     print("Download dataset")
    #     util.http_get(url, dataset_path)

    # Get all unique sentences from the file
    corpus_sentences = set()
    with open(dataset_path, encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
        for row in reader:
            corpus_sentences.add(row['question1'])
            if len(corpus_sentences) >= max_corpus_size:
                break

            corpus_sentences.add(row['question2'])
            if len(corpus_sentences) >= max_corpus_size:
                break

    corpus_sentences = list(corpus_sentences)
    print("Encode the corpus. This might take a while")
    corpus_embeddings = model.encode(corpus_sentences, show_progress_bar=True, convert_to_numpy=True)

    print("Store file on disc")
    with open(embedding_cache_path, "wb") as fOut:
        pickle.dump({'sentences': corpus_sentences, 'embeddings': corpus_embeddings}, fOut)
else:
    print("Load pre-computed embeddings from disc")
    with open(embedding_cache_path, "rb") as fIn:
        cache_data = pickle.load(fIn)
        corpus_sentences = cache_data['sentences']
        corpus_embeddings = cache_data['embeddings']

然后这个pkl会保存到本地,100k的编码文件大概有300M大小。

5. 创建索引并训练

首先设置一下相关的参数。

# 查询前K个结果
top_k_hits = 10
# 聚类的数量,这个数量一般介于4*sqrt(N) 到 16*sqrt(N),N是语料库的大小
n_clusters = 1024
# 在最相关的多少个簇中搜索答案,这个参数越大查的越全,消耗的时间也就越多
_nprobe = 3

然后创建索引对象,并传入训练数据。

# 创建
quantizer = faiss.IndexFlatIP(embedding_size)
index = faiss.IndexIVFFlat(quantizer, embedding_size, n_clusters, faiss.METRIC_INNER_PRODUCT)
index.nprobe = _nprobe

# 训练
# 因为对向量做点积计算以进行召回,先对所有语料的编码进行normalize
corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1)[:, None]
index.train(corpus_embeddings)
index.add(corpus_embeddings)

6. 应用

训练完成之后就可以应用了。

while True:
    inp_question = input("Please enter a question: ")

    start_time = time.time()
    question_embedding = model.encode(inp_question)

    #FAISS works with inner product (dot product). When we normalize vectors to unit length, inner product is equal to cosine similarity
    question_embedding = question_embedding / np.linalg.norm(question_embedding)
    question_embedding = np.expand_dims(question_embedding, axis=0)

	## 使用FAISS进行检索
    # Search in FAISS. It returns a matrix with distances and corpus ids.
    distances, corpus_ids = index.search(question_embedding, top_k_hits)

    # We extract corpus ids and scores for the first query
    hits = [{'corpus_id': id, 'score': score} for id, score in zip(corpus_ids[0], distances[0])]
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    end_time = time.time()

    print("Input question:", inp_question)
    print("Results (after {:.3f} seconds):".format(end_time-start_time))
    for hit in hits[0:top_k_hits]:
        print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))

	## 为了评估FAISS的效果,下面这部分是为了对照在全局检索的结果中,FAISS成功命中了多少
	## 在实际应用的时候要删除下面的内容
    # Approximate Nearest Neighbor (ANN) is not exact, it might miss entries with high cosine similarity
    # Here, we compute the recall of ANN compared to the exact results
    correct_hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k_hits)[0]
    correct_hits_ids = set([hit['corpus_id'] for hit in correct_hits])

    ann_corpus_ids = set([hit['corpus_id'] for hit in hits])
    if len(ann_corpus_ids) != len(correct_hits_ids):
        print("Approximate Nearest Neighbor returned a different number of results than expected")

    recall = len(ann_corpus_ids.intersection(correct_hits_ids)) / len(correct_hits_ids)
    print("\nApproximate Nearest Neighbor Recall@{}: {:.2f}".format(top_k_hits, recall * 100))

    if recall < 1:
        print("Missing results:")
        for hit in correct_hits[0:top_k_hits]:
            if hit['corpus_id'] not in ann_corpus_ids:
                print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))
    print("\n\n========\n")

以上就是本期全部内容了,如果这篇博客对你有帮助的话,还请一键三连支持一下up主,我们下期再见。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-01 11:55:38  更:2021-09-01 11:57:22 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 20:36:29-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码