IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 机器学习常用特征相似度或距离度量【Pytorch实现】 -> 正文阅读

[人工智能]机器学习常用特征相似度或距离度量【Pytorch实现】

余弦相似度、特征距离、KL散度计算

import torch

feat1 = torch.randn((3, 4))
feat2 = torch.randn((3, 4))

# ==========cosine相似度===============
a_norm = torch.linalg.norm(feat1, dim=1)
b_norm = torch.linalg.norm(feat2, dim=1)
cos = ((feat1 * feat2).sum(dim=(-2, -1)) / (a_norm * b_norm)).mean()
cos = 0.5 * cos + 0.5   # [-1, 1]  --> [0, 1]

# ==========矩阵范数===============
lfro = torch.linalg.norm(feat1 - feat2, dim=-1).mean()	# F范数
l1 = torch.linalg.norm(feat1 - feat2, dim=-1, ord=1).mean()
linf_max = torch.linalg.norm(feat1 - feat2, dim=-1, ord=float('inf')).mean()
linf_min = torch.linalg.norm(feat1 - feat2, dim=-1, ord=-float('inf')).mean()

# ==========KL散度===============
kl = torch.nn.functional.kl_div(feat1.softmax(dim=-1).log(), feat2.softmax(dim=-1), reduction='batchmean')	# 'mean'
kl2 = torch.nn.functional.kl_div(feat1.softmax(dim=-1).log(), feat2.softmax(dim=-1), reduction='sum')

若为三维矩阵:batch * channel * feature,则

import torch

feat1 = torch.randn((2, 3, 4))
feat2 = torch.randn((2, 3, 4))

# ==========cosine相似度===============
a_norm = torch.linalg.norm(feat1, dim=(-2, -1))
b_norm = torch.linalg.norm(feat2, dim=(-2, -1))
cos = ((feat1 * feat2).sum(dim=(-2, -1)) / (a_norm * b_norm)).mean()
cos = 0.5 * cos + 0.5   # [-1, 1]  --> [0, 1]

# ==========矩阵范数===============
lfro = torch.linalg.norm(feat1 - feat2, dim=(-2, -1)).mean()	# F范数
l1 = torch.linalg.norm(feat1 - feat2, dim=(-2, -1), ord=1).mean()
linf_max = torch.linalg.norm(feat1 - feat2, dim=(-2, -1), ord=float('inf')).mean()
linf_min = torch.linalg.norm(feat1 - feat2, dim=(-2, -1), ord=-float('inf')).mean()

# ==========KL散度===============
kl = torch.nn.functional.kl_div(feat1.softmax(dim=-1).log(), feat2.softmax(dim=-1), reduction='batchmean')	# 'mean'
kl2 = torch.nn.functional.kl_div(feat1.softmax(dim=-1).log(), feat2.softmax(dim=-1), reduction='sum')

Reference

Pytorch范数文档
向量与矩阵的范数
向量余弦相似度
KL散度

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-09-03 11:53:35  更:2021-09-03 11:55:29 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 19:54:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码