IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【推荐算法 学习与复现】-- 深度学习系列 -- DeepFM -> 正文阅读

[人工智能]【推荐算法 学习与复现】-- 深度学习系列 -- DeepFM

?整体内容和前面相似,FM部分二阶交叉部分数学原理参考如下:

简单理解FM公式 - 知乎 (zhihu.com)icon-default.png?t=M3K6https://zhuanlan.zhihu.com/p/354994307

class FM(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.w0 = nn.Parameter(torch.zeros([1,]))
        self.w1 = nn.Parameter(torch.rand([num_embeddings, 1]))
        self.w2 = nn.Parameter(torch.rand([num_embeddings, embedding_dim]))
    
    def forward(self, x):
        first_order = torch.mm(x, self.w1)
        second_order = 0.5 * torch.sum(
            torch.pow(torch.mm(x, self.w2), 2) - torch.mm(torch.pow(x, 2), torch.pow(self.w2, 2)),
            dim=1,
            keepdim=True
        )

        return self.w0 + first_order + second_order
        

class DNN(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(in_features, out_features, bias=True) for in_features, out_features in zip(hidden_units[:-1], hidden_units[1:])
        ])

    def forward(self, x):
        for layer in self.layers:
            x = F.relu(layer(x))
        return x


class DeepFM(nn.Module):
    def __init__(self, features_info, hidden_units, embedding_dim):
        super().__init__()

        # 解析特征信息
        self.dense_features, self.sparse_features, self.sparse_features_nunique = features_info

        # 解析拿到所有 数值型 和 稀疏型特征信息
        self.__dense_features_nums = len(self.dense_features)
        self.__sparse_features_nums = len(self.sparse_features)

        # embedding 
        self.embeddings = nn.ModuleDict({
            "embed_" + key : nn.Embedding(num_embeds, embedding_dim)
                for key, num_embeds in self.sparse_features_nunique.items()
        })

        stack_dim = self.__dense_features_nums + self.__sparse_features_nums * embedding_dim
        hidden_units.insert(0, stack_dim)

        self.fm = FM(stack_dim, embedding_dim)

        self.dnn = DNN(hidden_units)

        self.dnn_last_linear = nn.Linear(hidden_units[-1], 1, bias=False)

    def forward(self, x):

        # 从输入x中单独拿出 sparse_input 和 dense_input 
        dense_inputs, sparse_inputs = x[:, :self.__dense_features_nums], x[:, self.__dense_features_nums:]
        sparse_inputs = sparse_inputs.long()

        embedding_feas = [self.embeddings["embed_" + key](sparse_inputs[:, idx]) for idx, key in enumerate(self.sparse_features)]
        embedding_feas = torch.cat(embedding_feas, dim=-1)

        input_feas = torch.cat([embedding_feas, dense_inputs], dim=-1)

        fm = self.fm(input_feas)
        dnn = self.dnn_last_linear(self.dnn(input_feas))

        return F.sigmoid(fm + dnn)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-30 08:42:53  更:2022-04-30 08:44:09 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/6 17:38:47-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码