1. 基于卷积核扩充实现卷积
1.1 图解
1.2 代码
import torch
import torch.nn as nn
from torch.nn import functional as F
def conv2d_padding_kernel(input, kernel):
    """Compute a 2-D "valid" cross-correlation as a single matrix-vector product.

    For every output position the kernel is zero-padded up to the size of the
    input and flattened into one row of a large matrix.  Multiplying that
    matrix by the flattened input yields the flattened output (stride 1, no
    padding of the input, no kernel flipping).

    Args:
        input: 2-D tensor of shape (input_h, input_w).
        kernel: 2-D tensor of shape (kernel_h, kernel_w); must fit in `input`.

    Returns:
        2-D tensor of shape (input_h - kernel_h + 1, input_w - kernel_w + 1).

    Raises:
        ValueError: if the kernel is larger than the input along either axis.
    """
    input_h, input_w = input.shape
    kernel_h, kernel_w = kernel.shape
    if kernel_h > input_h or kernel_w > input_w:
        raise ValueError(
            f"kernel of shape {tuple(kernel.shape)} does not fit inside "
            f"input of shape {tuple(input.shape)}"
        )

    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1

    # One row per output element, one column per input element.
    kernel_matrix = torch.zeros(output_h * output_w, input_h * input_w)

    row_index = 0
    for i in range(output_h):      # vertical offset of the kernel window
        for j in range(output_w):  # horizontal offset of the kernel window
            # F.pad pads the LAST dimension first: (left, right, top, bottom).
            # The column offset j therefore goes first and the row offset i
            # second; swapping them transposes the result and produces
            # negative padding for non-square outputs.
            padded_kernel = F.pad(
                kernel,
                (j, input_w - kernel_w - j, i, input_h - kernel_h - i),
            )
            kernel_matrix[row_index] = torch.flatten(padded_kernel)
            row_index += 1

    output_vector = kernel_matrix @ torch.flatten(input)
    return output_vector.reshape((output_h, output_w))
# Demo: run the kernel-matrix convolution on a 4x4 input holding 0..15 and a
# 3x3 kernel holding 0..8, then print the operands and the result.
input = torch.arange(16, dtype=torch.float32).reshape(4, 4)
kernel = torch.arange(9, dtype=torch.float32).reshape(3, 3)
output = conv2d_padding_kernel(input, kernel)
print(f"input={input}", f"kernel={kernel}", f"output={output}", sep="\n")
input=tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])
kernel=tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]])
output=tensor([[258., 402.],
[294., 438.]])
2. 基于卷积核转置实现转置卷积
import torch
import torch.nn as nn
from torch.nn import functional as F
def conv2d_padding_kernel(input, kernel):
    """Compute a 2-D "valid" cross-correlation as a single matrix-vector product.

    Each row of the internal kernel matrix is the kernel zero-padded to the
    input's size at one window position and then flattened; rows are stored
    in row-major order of the output positions.  The product of that matrix
    with the flattened input is the flattened output (stride 1, no input
    padding, no kernel flipping).

    Args:
        input: 2-D tensor of shape (input_h, input_w).
        kernel: 2-D tensor of shape (kernel_h, kernel_w); must fit in `input`.

    Returns:
        2-D tensor of shape (input_h - kernel_h + 1, input_w - kernel_w + 1).

    Raises:
        ValueError: if the kernel is larger than the input along either axis.
    """
    input_h, input_w = input.shape
    kernel_h, kernel_w = kernel.shape
    if kernel_h > input_h or kernel_w > input_w:
        raise ValueError(
            f"kernel of shape {tuple(kernel.shape)} does not fit inside "
            f"input of shape {tuple(input.shape)}"
        )

    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1

    # One row per output element, one column per input element.
    kernel_matrix = torch.zeros(output_h * output_w, input_h * input_w)

    row_index = 0
    for i in range(output_h):      # vertical offset of the kernel window
        for j in range(output_w):  # horizontal offset of the kernel window
            # F.pad order is (left, right, top, bottom): horizontal offset j
            # first, vertical offset i second.  With the offsets in the right
            # slots the rows are already in row-major order, so no final
            # transpose is needed and non-square shapes work as well.
            padded_kernel = F.pad(
                kernel,
                (j, input_w - kernel_w - j, i, input_h - kernel_h - i),
            )
            kernel_matrix[row_index] = torch.flatten(padded_kernel)
            row_index += 1

    output_vector = kernel_matrix @ torch.flatten(input)
    return output_vector.reshape((output_h, output_w))
def transposed_conv2d_padding_kernel(input, kernel, output):
    """Compute a 2-D transposed convolution via the transposed kernel matrix.

    The kernel matrix K of the forward "valid" cross-correlation (one
    zero-padded, flattened kernel per output position) is built from the
    shapes of `input` and `kernel`.  The transposed convolution of `output`
    is then ``K^T @ flatten(output)``, reshaped to the shape of `input`.

    Args:
        input: 2-D tensor of shape (input_h, input_w).  Its shape fixes the
            shape of the result; its values are only used when `output` is
            None.
        kernel: 2-D tensor of shape (kernel_h, kernel_w); must fit in `input`.
        output: 2-D tensor of shape
            (input_h - kernel_h + 1, input_w - kernel_w + 1) to be
            transposed-convolved, or None to use the forward convolution of
            `input` with `kernel`.

    Returns:
        2-D tensor of shape (input_h, input_w).

    Raises:
        ValueError: if the kernel is larger than the input, or if `output`
            does not have the forward-convolution output shape.
    """
    input_h, input_w = input.shape
    kernel_h, kernel_w = kernel.shape
    if kernel_h > input_h or kernel_w > input_w:
        raise ValueError(
            f"kernel of shape {tuple(kernel.shape)} does not fit inside "
            f"input of shape {tuple(input.shape)}"
        )

    output_h = input_h - kernel_h + 1
    output_w = input_w - kernel_w + 1

    # One row per forward-output element, one column per input element.
    kernel_matrix = torch.zeros(output_h * output_w, input_h * input_w)

    row_index = 0
    for i in range(output_h):      # vertical offset of the kernel window
        for j in range(output_w):  # horizontal offset of the kernel window
            # F.pad order is (left, right, top, bottom): horizontal offset j
            # first, vertical offset i second, so rows follow the row-major
            # layout of the forward output.
            padded_kernel = F.pad(
                kernel,
                (j, input_w - kernel_w - j, i, input_h - kernel_h - i),
            )
            kernel_matrix[row_index] = torch.flatten(padded_kernel)
            row_index += 1

    if output is None:
        # No feature map supplied: fall back to the forward result of `input`.
        conv_output_vector = kernel_matrix @ torch.flatten(input)
    else:
        if tuple(output.shape) != (output_h, output_w):
            raise ValueError(
                f"output has shape {tuple(output.shape)}, "
                f"expected {(output_h, output_w)}"
            )
        # Match the kernel matrix dtype so the matmul below is well-defined.
        conv_output_vector = torch.flatten(output).to(kernel_matrix.dtype)

    output_matrix = kernel_matrix.transpose(-1, -2) @ conv_output_vector
    return output_matrix.reshape((input_h, input_w))
# Demo: convolve a 4x4 input with a 3x3 kernel, push the result back through
# the kernel-matrix transposed convolution, and compare it element-wise with
# PyTorch's built-in F.conv_transpose2d.
input = torch.arange(16, dtype=torch.float32).reshape(4, 4)
kernel = torch.arange(9, dtype=torch.float32).reshape(3, 3)
output = conv2d_padding_kernel(input, kernel)
input_transpose = transposed_conv2d_padding_kernel(input=input, kernel=kernel, output=output)
# F.conv_transpose2d expects (batch, channel, H, W) operands, so two leading
# singleton dimensions are added and squeezed away again afterwards.
pytorch_transpose = F.conv_transpose2d(output[None, None], kernel[None, None]).squeeze()
print(
    f"input={input}",
    f"kernel={kernel}",
    f"output={output}",
    f"input_transpose={input_transpose}",
    f"pytorch_transpose={pytorch_transpose}",
    sep="\n",
)
print(f"torch.isclose(input_transpose,pytorch_transpose)\n={torch.isclose(input_transpose,pytorch_transpose)}")
input=tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])
kernel=tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]])
output=tensor([[258., 294.],
[402., 438.]])
input_transpose=tensor([[ 0., 258., 810., 588.],
[ 774., 2316., 3708., 2346.],
[2754., 6492., 7884., 4542.],
[2412., 5442., 6282., 3504.]])
pytorch_transpose=tensor([[ 0., 258., 810., 588.],
[ 774., 2316., 3708., 2346.],
[2754., 6492., 7884., 4542.],
[2412., 5442., 6282., 3504.]])
torch.isclose(input_transpose,pytorch_transpose)
=tensor([[True, True, True, True],
[True, True, True, True],
[True, True, True, True],
[True, True, True, True]])
|