IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> tensorflow recommenders 系列3:双塔召回模型 -> 正文阅读

[Python知识库]tensorflow recommenders 系列3:双塔召回模型

要求:tensorflow>2.6.0

from typing import Dict, Text

import keras.layers
import tensorflow as tf
from keras import Model

import tensorflow_recommenders as tfrs

###############################模型编写###############################

class RecUserModel(tf.keras.Model):
    '''
    用户属性支持:用户姓名、用户点击数据(最近前三)、点赞数据(最近前三)、收藏数据(最贱前三)
    '''
    def __init__(self,user_tag_vocabulary):
        super().__init__()
        self.userTags_vector = tf.keras.layers.StringLookup()
        self.userTags_vector.adapt(user_tag_vocabulary)
        self.userTags_embedding = keras.layers.Embedding(input_dim=len(self.userTags_vector.get_vocabulary()),output_dim=4)

    #@tf.function(input_signature=({"userTags":tf.TensorSpec(shape=(None,None), dtype=tf.dtypes.string, name="userTags"),"age":tf.TensorSpec(shape=(None,None), dtype=tf.dtypes.float32, name="age")},))
    def call(self, inputs):
        tags_lookup = self.userTags_vector(inputs.get("userTags"))
        user_embedding = tf.math.reduce_sum(self.userTags_embedding(tags_lookup),axis=-2,keepdims=False)
        return user_embedding+inputs.get("age")

class RecItemModel(tf.keras.Model):
    def __init__(self, item_tag_vocabulary):
        super().__init__()
        self.userTags_vector =tf.keras.layers.StringLookup()
        self.userTags_vector.adapt(item_tag_vocabulary)
        self.userTags_embedding = keras.layers.Embedding(input_dim=len(self.userTags_vector.get_vocabulary()),
                                                         output_dim=4)
    #@tf.function(input_signature=({"itemTags":tf.TensorSpec(shape=(None,None), dtype=tf.dtypes.string, name="itemTags")},))
    def call(self, inputs):
        tags_lookup = self.userTags_vector(inputs.get("itemTags"))
        user_embedding = tf.math.reduce_sum(self.userTags_embedding(tags_lookup), axis=1, keepdims=False)
        return user_embedding

class ItemRecModel(tfrs.Model):
    # We derive from a custom base class to help reduce boilerplate. Under the hood,
    # these are still plain Keras Models.

    def __init__(
            self,
            user_model: tf.keras.Model,
            item_model: tf.keras.Model,
            task: tfrs.tasks.Retrieval):
        super().__init__()

        # Set up user and movie representations.
        self.user_model = user_model
        self.movie_model = item_model

        # Set up a retrieval task.
        self.task = task

    def compute_loss(self, features: Dict[Text, Dict], training=False) -> tf.Tensor:
        # Define how the loss is computed.
        user_embeddings = self.user_model(features["user_features"])
        #
        movie_embeddings = self.movie_model(features["item_features"])

        return self.task(user_embeddings, movie_embeddings)

###################################数据处理逻辑########################################
'''
数据样例:
age	userTag	itemTag
13	a,b,c	c,d,e
14	e,h,g	n,k,m
15	e,f,g	n,k,m
16	e,d,g	验,过,m
14	e,n,g	n,e,m
11	e,m,g	n,k,m
19	e,s,g	n,c,m
'''
def str_json(message):
'''
解决标签长度不一致问题
'''
    return {
        "user_features": {
            "age": [tf.strings.to_number(message[0], tf.float32)],
            "usertags": tf.pad(tf.strings.split(message[1], sep=','),[[0, 5]])[:5]
        },
        "item_features": {
            "itemtags": tf.pad(tf.strings.split(message[2], sep=','),[[0, 5]])[:5]
        }
    }
original_data =tf.data.TextLineDataset(['G:/git_alg/prepare_project/recommenders/data/test.csv'],num_parallel_reads=2)
map_result = original_data.skip(1).map(lambda x:tf.strings.split(x,sep='\t'))\
    .map(str_json)


userTag_vocabulary =map_result.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x["user_features"]["userTags"])).unique()
itemTag_vocabulary = map_result.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x["item_features"]["itemTags"])).unique()


##########################模型训练######################################
user_model = RecUserModel(userTag_vocabulary)
item_model = RecItemModel(itemTag_vocabulary)
#
task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
    map_result.map(lambda x:x["item_features"]).batch(5).map(item_model),k=3
)
)


item_rec_model = ItemRecModel(user_model,item_model,task)

item_rec_model.compile(optimizer=tf.keras.optimizers.Adagrad(0.5))

model_file ="G:/git_alg/prepare_project/recommenders/saved_model/"
callback = tf.keras.callbacks.ModelCheckpoint(filepath=model_file,
                                              save_weights_only=True,
                                              verbose=1)


item_rec_model.fit(map_result.batch(2), epochs=3,callbacks=[callback])
#
tf.keras.models.save_model(user_model,'G:/git_alg/prepare_project/recommenders/user_model')
tf.keras.models.save_model(item_model,'G:/git_alg/prepare_project/recommenders/item_model')
###############################检索样例#######################



user_model =  tf.keras.models.load_model('G:/git_alg/prepare_project/recommenders/user_model',compile=False)
item_model =  tf.keras.models.load_model('G:/git_alg/prepare_project/recommenders/item_model',compile=False)
index = tfrs.layers.factorized_top_k.BruteForce(user_model,k=2)
index.index_from_dataset(
    map_result.map(lambda x:x["item_features"]).batch(3).map(lambda x: item_model(x)   ))


# # Get some recommendations.注意维度匹配
titles = index({"age":tf.constant([[11]],dtype=tf.float32),"userTags":tf.constant([["n","k"]],dtype=tf.string)})
print(f"Top 3 recommendations for user 42: {titles}")

具体的可执行代码:GitHub - guoyandan/rec_model: python 版本的推荐模型(数据加载、多特征处理、多特征组合关联推荐、数据标签处理)

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-04-28 11:48:54  更:2022-04-28 11:49:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 8:27:42-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计