IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 25 - 转置卷积的原理与实现 -> 正文阅读

[人工智能]25 - 转置卷积的原理与实现

1. 基于卷积核扩充实现卷积

1.1 图解

在这里插入图片描述

1.2 代码

import torch
import torch.nn as nn
from torch.nn import functional as F


# 默认操作为二维操作,stride=1,padding=0
def conv2d_padding_kernel(input, kernel):
	input_h, input_w = input.shape
	kernel_h, kernel_w = kernel.shape
	input_vector = torch.flatten(input)
	output_w = input_w - kernel_w + 1
	output_h = input_h - kernel_h + 1
	kernel_matrix_h = output_w * output_h
	kernel_matrix_w = input_w * input_h
	kernel_matrix = torch.zeros(kernel_matrix_h, kernel_matrix_w)
	row_index = 0
	for i in range(0, output_h, 1):
		for j in range(0, output_w, 1):
			padded_kernel = F.pad(kernel, (i, input_w - kernel_w - i, j, input_h - kernel_h - j))
			padded_kernel = torch.flatten(padded_kernel)
			kernel_matrix[row_index] = padded_kernel
			row_index += 1
	output_matrix = kernel_matrix @ input_vector
	output_matrix = output_matrix.reshape((output_h, output_w))
	return output_matrix


input = torch.arange(16, dtype=torch.float).reshape((4, 4))
kernel = torch.arange(9, dtype=torch.float).reshape((3, 3))
output = conv2d_padding_kernel(input, kernel)
print(f"input={input}")
print(f"kernel={kernel}")
print(f"output={output}")
  • 结果:
input=tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
kernel=tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
output=tensor([[258., 402.],
        [294., 438.]])

2. 基于卷积核转置实现转置卷积

import torch
import torch.nn as nn
from torch.nn import functional as F


# 默认操作为二维操作,stride=1,padding=0
def conv2d_padding_kernel(input, kernel):
	input_h, input_w = input.shape
	kernel_h, kernel_w = kernel.shape
	input_vector = torch.flatten(input)
	output_w = input_w - kernel_w + 1
	output_h = input_h - kernel_h + 1
	kernel_matrix_h = output_w * output_h
	kernel_matrix_w = input_w * input_h
	kernel_matrix = torch.zeros(kernel_matrix_h, kernel_matrix_w)
	row_index = 0
	for i in range(0, output_h, 1):
		for j in range(0, output_w, 1):
			padded_kernel = F.pad(kernel, (i, input_w - kernel_w - i, j, input_h - kernel_h - j))
			padded_kernel = torch.flatten(padded_kernel)
			kernel_matrix[row_index] = padded_kernel
			row_index += 1
	output_matrix = kernel_matrix @ input_vector
	output_matrix = output_matrix.reshape((output_h, output_w))
	return output_matrix.T


def transposed_conv2d_padding_kernel(input, kernel, output):
	input_h, input_w = input.shape
	kernel_h, kernel_w = kernel.shape
	input_vector = torch.flatten(input)
	output_w = input_w - kernel_w + 1
	output_h = input_h - kernel_h + 1
	kernel_matrix_h = output_w * output_h
	kernel_matrix_w = input_w * input_h
	kernel_matrix = torch.zeros(kernel_matrix_h, kernel_matrix_w)
	row_index = 0
	for i in range(0, output_h, 1):
		for j in range(0, output_w, 1):
			padded_kernel = F.pad(kernel, (i, input_w - kernel_w - i, j, input_h - kernel_h - j))
			padded_kernel = torch.flatten(padded_kernel)
			kernel_matrix[row_index] = padded_kernel
			row_index += 1
	conv_ouput_matrix = kernel_matrix @ input_vector
	# 转置卷积体现在这里!!!
	output_matrix = kernel_matrix.transpose(-1, -2) @ conv_ouput_matrix
	return output_matrix.reshape((input_h, input_w))


input = torch.arange(16, dtype=torch.float).reshape((4, 4))
kernel = torch.arange(9, dtype=torch.float).reshape((3, 3))
output = conv2d_padding_kernel(input, kernel)
input_transpose = transposed_conv2d_padding_kernel(input=input, kernel=kernel, output=output)
pytorch_transpose = F.conv_transpose2d(output.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0)).squeeze()
print(f"input={input}")
print(f"kernel={kernel}")
print(f"output={output}")
print(f"input_transpose={input_transpose}")
print(f"pytorch_transpose={pytorch_transpose}")
print(f"torch.isclose(input_transpose,pytorch_transpose)\n={torch.isclose(input_transpose,pytorch_transpose)}")
input=tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
kernel=tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
output=tensor([[258., 294.],
        [402., 438.]])
input_transpose=tensor([[   0.,  258.,  810.,  588.],
        [ 774., 2316., 3708., 2346.],
        [2754., 6492., 7884., 4542.],
        [2412., 5442., 6282., 3504.]])
pytorch_transpose=tensor([[   0.,  258.,  810.,  588.],
        [ 774., 2316., 3708., 2346.],
        [2754., 6492., 7884., 4542.],
        [2412., 5442., 6282., 3504.]])
torch.isclose(input_transpose,pytorch_transpose)
=tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-07 22:41:41  更:2022-04-07 22:42:04 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 4:47:40-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码