IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> tensorflow recommenders 系列2:召回模型介绍 -> 正文阅读

[人工智能]tensorflow recommenders 系列2:召回模型介绍

基础教程部分,参考:Recommending movies: retrieval ?|? TensorFlow Recommenders

在召回模型训练阶段涉及到

task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
    movies.batch(5).map(movie_model),k=3
)

其作用有两个,一个是返回定义召回效果评估度量标准FactorizedTopK,另一个是定义了损失函数的计算方法,并对每一个批次返回对应的损失计算结果。

首先介绍损失函数部分Retrieval对应的call方法(Retrieval方法只适合双塔模型),

  def call(self,
           query_embeddings: tf.Tensor,
           candidate_embeddings: tf.Tensor,
           sample_weight: Optional[tf.Tensor] = None,
           candidate_sampling_probability: Optional[tf.Tensor] = None,
           candidate_ids: Optional[tf.Tensor] = None,
           compute_metrics: bool = True) -> tf.Tensor:
    """Computes the task loss and metrics.

    The main argument are pairs of query and candidate embeddings: the first row
    of query_embeddings denotes a query for which the candidate from the first
    row of candidate embeddings was selected by the user.

    The task will try to maximize the affinity of these query, candidate pairs
    while minimizing the affinity between the query and candidates belonging
    to other queries in the batch.

    Args:
      query_embeddings: [num_queries, embedding_dim] tensor of query
        representations.
      candidate_embeddings: [num_queries, embedding_dim] tensor of candidate
        representations.
      sample_weight: [num_queries] tensor of sample weights.
      candidate_sampling_probability: Optional tensor of candidate sampling
        probabilities. When given will be be used to correct the logits to
        reflect the sampling probability of negative candidates.
      candidate_ids: Optional tensor containing candidate ids. When given
        enables removing accidental hits of examples used as negatives. An
        accidental hit is defined as an candidate that is used as an in-batch
        negative but has the same id with the positive candidate.
      compute_metrics: Whether to compute metrics. Set this to False
        during training for faster training.

    Returns:
      loss: Tensor of loss values.
    """
    #此处将两个向量的乘积结果作为用户和实体向量的相似度值,参考https://www.cnblogs.com/daniel-D/p/3244718.html
    scores = tf.linalg.matmul(
        query_embeddings, candidate_embeddings, transpose_b=True)

    num_queries = tf.shape(scores)[0]
    num_candidates = tf.shape(scores)[1]
    #根据结果构造对角函数,作为预期结果矩阵
    labels = tf.eye(num_queries, num_candidates)

    metric_update_ops = []
    if compute_metrics:
      if self._factorized_metrics:
        metric_update_ops.append(
            self._factorized_metrics.update_state(query_embeddings,
                                                  candidate_embeddings))
      if self._batch_metrics:
        metric_update_ops.extend([
            batch_metric.update_state(labels, scores)
            for batch_metric in self._batch_metrics
        ])

    if self._temperature is not None:
      #
      scores = scores / self._temperature

    if candidate_sampling_probability is not None:
      scores = layers.loss.SamplingProbablityCorrection()(
          scores, candidate_sampling_probability)

    if candidate_ids is not None:
      scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)

    if self._num_hard_negatives is not None:
      scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(
          scores,
          labels)

    loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)

    if not metric_update_ops:
      return loss

    with tf.control_dependencies(metric_update_ops):
      return tf.identity(loss)
  • 关于score和labels的计算已经在代码中进行了注解
  • 关于温度淬火法参数的解释self._temperature参考:深度学习中的temperature parameter是什么 - 知乎
  • 关于损失函数,默认的损失函数为类别交叉熵损失函数(对应ont-hot表示法),如果是integer表示法,可以换为
    SparseCategoricalCrossentropy
tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

?对应的调用样例参考:

>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
  >>> cce = tf.keras.losses.CategoricalCrossentropy()
  >>> cce(y_true, y_pred).numpy()
  1.177

  >>> # Calling with 'sample_weight'.
  >>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
  0.814

  >>> # Using 'sum' reduction type.
  >>> cce = tf.keras.losses.CategoricalCrossentropy(
  ...     reduction=tf.keras.losses.Reduction.SUM)
  >>> cce(y_true, y_pred).numpy()
  2.354

关于度量标准部分?,此处重点理解TopKCategoricalAccuracy

查看源码解释既可明白,其是对每一个用户的topk推荐结果(是数值topk而非位置)是否包含了预期的结果,统计所有用户的对应情况占比。

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  >>> m.result().numpy()
#1/2
  0.5

  >>> m.reset_state()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
#0*0.7+1*0.3
  0.3

在召回模型使用阶段,

?如果候选实体比较少的时候,可以使用暴力求解法:

index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
index.index_from_dataset(
    movies.batch(100).map(lambda title: (title, model.movie_model(title))))

# Get some recommendations.
_, titles = index(np.array(["42"]))
print(f"Top 3 recommendations for user 42: {titles[0, :3]}")

如果候选实体比较大的时候,可以使用近似ScaNN方法,tfrs.layers.factorized_top_k.ScaNN。关于ScaNN部分,也可以通过类似milvas等向量数据库的查询方法得以实现

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-23 10:50:39  更:2022-04-23 10:51:50 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 23:30:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码