IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 34 - Swin-Transformer论文精讲及其PyTorch逐行复现 -> 正文阅读

[人工智能]34 - Swin-Transformer论文精讲及其PyTorch逐行复现

1. 两种方法实现Patch_Embedding

import torch
from torch.nn import functional as F



# method_1 : using unfold to achieve the patch_embedding
# step_1: unfold the image
# step_2: unfold_output@weight
def image2embed_naive(image, patch_size, weight):
	"""
	:param image: [bs,in_channel,height,width]
	:param patch_size:
	:param weight : weight.shape=[patch_depth=in_channel*patch_size*patch_size,model_dim_C]
	:return: patch_embedding,it shape is [batch_size,num_patches,model_dim_C]
	"""

	# patch_depth = in_channel*patch_size*patch_size
	# image_output.shape = [batch_size,num_patch,patch_depth=in_channel*patch_size*patch_size]
	image_output = F.unfold(image, kernel_size=(patch_size, patch_size),
							stride=(patch_size, patch_size)).transpose(-1, -2)

	# change the final_channel dimension from patch_depth to model_dim_C
	patch_embedding = image_output @ weight

	return patch_embedding



# using F.conv2d to achieve the patch_embedding
def image2conv(image, weight, patch_size):
	# image =[batch_size,in_channel,height,width]
	# weight = [out_channels,in_channels,kernel_h,kernel_w]
	conv_output = F.conv2d(image, weight=weight, stride=patch_size)
	bs, oc, oh, ow = conv_output.shape
	patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1,-2)

	return patch_embedding


batch_size = 1
in_channel = 2
out_channel = 5
height = 3
width = 4
input = torch.randn(batch_size, in_channel, height, width)

patch_size = 2

weight1_depth = in_channel * patch_size * patch_size

weight1_model_c = out_channel

weight1 = torch.randn(weight1_depth,weight1_model_c)

weight2_out_channel = weight1_model_c


weight2 = weight1.transpose(0,1).reshape(weight1_model_c,in_channel,patch_size,patch_size)

output1 = image2embed_naive(input, patch_size, weight1)

output2 = image2conv(input, weight2, patch_size)


# flag the check output1 is the same for output2
# if flag is true ,they are the same
flag = torch.isclose(output1,output2)
print(f"flag={flag}")
print(f"output1={output1}")
print(f"output2={output2}")
print(f"output1.shape={output1.shape}")
print(f"output2.shape={output2.shape}")

2. 多头自注意力(Multi_Head_Self_Attention)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-16 11:19:46  更:2022-05-16 11:20:57 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/2 0:25:03-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码