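Computing Convolution in PyTorch with FFT, Matrix Multiplication, and Conv2d
Goal: compute the convolution Y = X*H of a 64*64 matrix X and a 3*3 matrix H.
Section 1: Import libraries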
import torch
import torch.nn as nn
from timeit import Timer
# Create a random 4-D tensor: batch size 1, 1 channel, a 64*64 "image" with integer values in [0, 128)
x_n = torch.randint(0, 128, [1, 1, 64, 64]).to(torch.float32)
# Random 3*3 kernel in the same (N, C, H, W) layout
h_n = torch.randn(1, 1, 3, 3)
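Section 2: Define the functions
2.1 Convolution via FFT
def fft_test():
    """Compute the convolution with the FFT."""
    x_n_fft = torch.fft.fft2(x_n[0, 0])
    # Flip the kernel in both directions so that the circular convolution
    # computed by the FFT matches what nn.Conv2d returns (cross-correlation)
    h_trans = torch.flip(h_n[0, 0], [1])
    h_trans = torch.flipud(h_trans)
    # Zero-pad the 3*3 kernel to 64*64 so both spectra have the same shape
    pad = nn.ZeroPad2d(padding=(0, 61, 0, 61))
    h_n_pad = pad(h_trans)
    h_n_fft = torch.fft.fft2(h_n_pad)
    # Pointwise product in the frequency domain
    res = x_n_fft.mul(h_n_fft)
    # Drop the first 2 rows and columns, which contain wrap-around terms,
    # leaving the 62*62 "valid" region
    res = torch.real(torch.fft.ifft2(res)[2:, 2:]).view(1, 1, 62, 62)
    return res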
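As a cross-check of the flip, zero-pad and crop steps, the following minimal sketch compares an FFT-based result with torch.nn.functional.conv2d. The helper name fft_xcorr_valid, the fixed seed, the use of float64 and the s= argument of torch.fft.fft2 for zero-padding are choices made for this illustration and are not part of the original code.

```python
import torch
import torch.nn.functional as F

def fft_xcorr_valid(x, h):
    """Valid-mode cross-correlation of a square 2-D image with a square 2-D kernel via the FFT.

    Hypothetical helper for illustration; assumes x is n*n and h is k*k with k <= n.
    """
    n, k = x.shape[-1], h.shape[-1]
    h_flipped = torch.flip(h, dims=[0, 1])          # flip rows and columns
    x_fft = torch.fft.fft2(x)
    h_fft = torch.fft.fft2(h_flipped, s=(n, n))     # s= zero-pads the kernel to n*n
    y = torch.fft.ifft2(x_fft * h_fft).real         # circular convolution
    return y[k - 1:, k - 1:]                        # discard wrap-around rows/columns

torch.manual_seed(0)
x = torch.randint(0, 128, (64, 64)).to(torch.float64)
h = torch.randn(3, 3, dtype=torch.float64)

y_fft = fft_xcorr_valid(x, h)
y_ref = F.conv2d(x[None, None], h[None, None])[0, 0]  # reference, no bias
print(y_fft.shape, torch.allclose(y_fft, y_ref, atol=1e-8))
```

float64 is used here only so that a tight tolerance can be applied; with float32 the two results would be expected to agree to roughly four decimal places, as in the outputs shown later.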
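2.2 Convolution via matrix multiplication
def multi_test():
    """Compute the convolution by multiplying each 3*3 window with the kernel."""
    res = []
    for i in range(62):
        for j in range(62):
            x_n_n = x_n[0, 0, i:i+3, j:j+3]   # 3*3 window at position (i, j)
            res0 = torch.sum(x_n_n.mul(h_n))  # element-wise product, then sum
            res.append(res0)
    res = torch.Tensor(res).view([1, 1, 62, 62])
    return res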
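The nested Python loop can also be expressed as a single matrix product by first unrolling every 3*3 window into a column. The sketch below uses torch.nn.functional.unfold for that step; it is an alternative formulation for illustration, not the method timed in this article, and the variable names and tolerance are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randint(0, 128, (1, 1, 64, 64)).to(torch.float32)
h = torch.randn(1, 1, 3, 3)

# Each column of `cols` is one flattened 3*3 window: shape (1, 9, 62*62)
cols = F.unfold(x, kernel_size=3)
# (1, 9) @ (1, 9, 3844) -> (1, 1, 3844): one dot product per window
y_mat = (h.view(1, 9) @ cols).view(1, 1, 62, 62)

# Reference result without bias
y_ref = F.conv2d(x, h)
print(torch.allclose(y_mat, y_ref, atol=1e-2))
```

Because the per-window work is done inside one batched product rather than in 62*62 Python iterations, this form would be expected to run much faster than the loop, although that is not measured here.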
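2.3 Convolution via the built-in function
def conv2d_test():
    """Compute the convolution with the built-in nn.Conv2d."""
    # Note: bias defaults to True, so a randomly initialised bias is added to
    # every output element; pass bias=False for a pure convolution
    conv = nn.Conv2d(1, 1, 3)
    conv.weight.data = h_n
    res = conv(x_n)
    return res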
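2.4 Timing the functions
def test_time(test_name=""):
    """Time a function. Input: the name of the function to test."""
    test_import = "from __main__ import " + test_name
    test_name = test_name + "()"
    test = Timer(test_name, test_import)
    # timeit returns the total time in seconds for 1000 calls, which is
    # numerically equal to the average time per call in milliseconds
    time = test.timeit(number=1000)
    print("%s is run %.3f ms" % (test_name, time))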
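The unit conversion can be made explicit by dividing the total by the number of calls. The following is a minimal sketch using only the standard library; ms_per_call and busy are hypothetical names introduced for this example, and a callable is passed to Timer instead of an import string.

```python
from timeit import Timer

def ms_per_call(func, number=1000):
    """Average wall-clock time of func() in milliseconds (hypothetical helper)."""
    total_s = Timer(func).timeit(number=number)  # total seconds for `number` calls
    return total_s / number * 1000.0             # seconds per call -> milliseconds

def busy():
    """Stand-in workload for the example."""
    return sum(range(1000))

elapsed_ms = ms_per_call(busy, number=200)
print("busy() takes %.4f ms per call" % elapsed_ms)
```

With number=1000 the division by 1000 and the multiplication by 1000 cancel, which is why the total in seconds can be read directly as milliseconds per call.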
Section 3: Tests
3.1 Running-time test
test_time("fft_test")
test_time("multi_test")
test_time("conv2d_test")
>>> fft_test() is run 0.262 ms
multi_test() is run 61.161 ms
conv2d_test() is run 0.187 ms
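Conclusion: the loop-based matrix multiplication is by far the slowest (about 61 ms per call); the FFT version (0.262 ms) is slightly slower than the built-in function (0.187 ms).
3.2 Output of each method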
fft_test()
>>> tensor([[[[ 186.7925, 64.8263, -217.0269, ..., -54.8264, -16.5303,
-57.4855],
[ 105.5789, 103.1986, -45.4856, ..., 108.9532, -11.3397,
-74.9136],
[ 37.4062, -159.4092, -82.7995, ..., -18.7437, 103.2889,
-56.7241],
...,
[ -57.3688, 205.4553, 114.1560, ..., 61.6281, -30.6847,
123.7901],
[ 107.0706, 222.8759, 189.9121, ..., 118.7491, 61.0583,
154.0859],
[ 157.7202, 55.5320, -37.9096, ..., -37.6815, 39.8024,
-123.9654]]]])
multi_test()
>>> tensor([[[[ 186.7925, 64.8263, -217.0269, ..., -54.8264, -16.5303,
-57.4854],
[ 105.5789, 103.1987, -45.4857, ..., 108.9533, -11.3397,
-74.9136],
[ 37.4062, -159.4092, -82.7995, ..., -18.7437, 103.2889,
-56.7241],
...,
[ -57.3689, 205.4553, 114.1560, ..., 61.6282, -30.6847,
123.7902],
[ 107.0706, 222.8759, 189.9121, ..., 118.7491, 61.0583,
154.0859],
[ 157.7202, 55.5320, -37.9096, ..., -37.6815, 39.8024,
-123.9654]]]])
conv2d_test()
>>> tensor([[[[ 187.0928, 65.1265, -216.7267, ..., -54.5261, -16.2301,
-57.1852],
[ 105.8792, 103.4989, -45.1854, ..., 109.2535, -11.0395,
-74.6133],
[ 37.7064, -159.1090, -82.4993, ..., -18.4435, 103.5891,
-56.4239],
...,
[ -57.0686, 205.7555, 114.4562, ..., 61.9284, -30.3845,
124.0904],
[ 107.3709, 223.1761, 190.2123, ..., 119.0493, 61.3585,
154.3862],
[ 158.0205, 55.8323, -37.6093, ..., -37.3813, 40.1026,
-123.6652]]]], grad_fn=<ThnnConv2DBackward>)
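Conclusion: the FFT and matrix-multiplication results agree to within float32 rounding (differences of about 0.0001). Every Conv2d value shown is larger by a constant of about 0.30, which is consistent with the randomly initialised bias that nn.Conv2d adds by default (bias=True) rather than with numerical instability; apart from this offset the three results agree.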
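nn.Conv2d creates a bias term by default, and its effect on the output can be checked with a short sketch like the one below. The seed, the variable names and the tolerances are assumptions of this illustration; the data are freshly generated, so the numbers will not match the outputs printed above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randint(0, 128, (1, 1, 64, 64)).to(torch.float32)
h = torch.randn(1, 1, 3, 3)

conv_bias = nn.Conv2d(1, 1, 3)                # default: bias=True, randomly initialised
conv_nobias = nn.Conv2d(1, 1, 3, bias=False)  # pure convolution

with torch.no_grad():
    conv_bias.weight.copy_(h)
    conv_nobias.weight.copy_(h)
    y_ref = F.conv2d(x, h)            # functional form, no bias
    y_bias = conv_bias(x)
    y_nobias = conv_nobias(x)
    offset = y_bias - y_ref           # should be the same constant everywhere

print("bias value:", conv_bias.bias.item())
print("offset range:", offset.min().item(), offset.max().item())
print("bias=False matches reference:", torch.allclose(y_nobias, y_ref, atol=1e-2))
```

If the offset is constant and equal to the bias, setting bias=False (or zeroing the bias) should bring Conv2d into agreement with the other two methods.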
3.3 Complexity analysis
import torch
import torch.nn as nn
from timeit import Timer
import big_o
h_n = torch.randn(1,1,3,3)
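3.3.1 Estimating the time complexity of the FFT method
def fft_test(x_n):
    """Compute the convolution with the FFT for an n*n input."""
    n = x_n.shape[-1]  # image side length
    x_n_fft = torch.fft.fft2(x_n[0, 0])
    h_trans = torch.flip(h_n[0, 0], [1])
    h_trans = torch.flipud(h_trans)
    # Zero-pad the 3*3 kernel to n*n so both spectra have the same shape
    pad = nn.ZeroPad2d(padding=(0, n - 3, 0, n - 3))
    h_n_pad = pad(h_trans)
    h_n_fft = torch.fft.fft2(h_n_pad)
    res = x_n_fft.mul(h_n_fft)
    # Keep the (n-2)*(n-2) valid region
    res = torch.real(torch.fft.ifft2(res)[2:, 2:]).view(1, 1, n - 2, n - 2)
    return res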
The program shows that the cost is dominated by two FFTs and one IFFT. Let the image be n*n and the kernel 3*3.
Taking the highest-order term, the complexity of the FFT is $n^2 \log n$, so the overall complexity is estimated as $O(n^2 \log n)$.
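The estimate can be written out as a rough operation count. As an assumption of this sketch, each 2-D transform of an n*n array is taken to cost on the order of $c\,n^2 \log_2 n$ operations for some unspecified constant $c$, and the pointwise product of the two spectra costs $n^2$:

```latex
\begin{aligned}
T(n) &\approx 3\,c\,n^2 \log_2 n + n^2 \\
     &= O(n^2 \log n)
\end{aligned}
```

The factor 3 counts the two forward transforms and the one inverse transform; the kernel is padded to n*n, so its transform costs as much as that of the image.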
The big_o module is used to estimate the complexity empirically. The generator below ignores its argument n and always returns an image of the same size, so the size was edited by hand between runs:
positive_int_generator = lambda n: torch.randint(0, 128, [1, 1, 64, 64]).to(torch.float32)
best, other = big_o.big_o(fft_test, positive_int_generator, n_repeats=100)
print(best)
Best fit reported for each image size:
520*520 >>> Logarithmic: time = 0.089 + -0.0031*log(n) (sec)
220*220 >>> Logarithmic: time = 0.11 + -0.0043*log(n) (sec)
120*120 >>> Logarithmic: time = 0.59 + -0.019*log(n) (sec)
60*60 >>> Logarithmic: time = 0.089 + -0.0054*log(n) (sec)
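Across these image sizes big_o reports a logarithmic best fit, $O(\log n)$. The fitted log coefficients are all negative, however, and the generator never varies the input size with n, so these fits most likely reflect timing noise rather than the true scaling. The measurements therefore neither confirm nor contradict the $O(n^2 \log n)$ estimate.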
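To measure how the running time actually grows, the input size has to change with n. The sketch below does this with timeit and a log-log least-squares fit instead of big_o; the helper names, the chosen sizes and the repeat count are assumptions of this illustration, and the timings will vary from machine to machine.

```python
import math
from timeit import Timer
import torch

def fft_xcorr(x, h):
    """FFT-based valid cross-correlation of an n*n image with a k*k kernel (illustrative)."""
    n, k = x.shape[-1], h.shape[-1]
    h_fft = torch.fft.fft2(torch.flip(h, dims=[0, 1]), s=(n, n))
    return torch.fft.ifft2(torch.fft.fft2(x) * h_fft).real[k - 1:, k - 1:]

def loglog_slope(ns, ts):
    """Least-squares slope of log(t) against log(n): the empirical growth exponent."""
    lx = [math.log(n) for n in ns]
    ly = [math.log(t) for t in ts]
    mx, my = sum(lx) / len(lx), sum(ly) / len(ly)
    num = sum((a - mx) * (b - my) for a, b in zip(lx, ly))
    den = sum((a - mx) ** 2 for a in lx)
    return num / den

h = torch.randn(3, 3)
sizes = [64, 128, 256, 512]
times = []
for n in sizes:
    x = torch.randint(0, 128, (n, n)).to(torch.float32)  # input size depends on n
    total_s = Timer(lambda: fft_xcorr(x, h)).timeit(number=20)
    times.append(total_s / 20)                           # seconds per call

slope = loglog_slope(sizes, times)
print("seconds per call:", times)
print("fitted exponent: %.2f" % slope)
```

For $O(n^2 \log n)$ growth one would expect an exponent somewhat above 2 at large sizes; at small sizes fixed per-call overhead may dominate and pull the fitted value lower.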
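3.3.2 Estimating the time complexity of the matrix-multiplication method
def multi_test(x_n):
    """Compute the convolution by multiplying each 3*3 window with the kernel."""
    n = x_n.shape[-1]  # image side length
    res = []
    for i in range(n - 2):
        for j in range(n - 2):
            x_n_n = x_n[0, 0, i:i+3, j:j+3]   # 3*3 window at position (i, j)
            res0 = torch.sum(x_n_n.mul(h_n))  # element-wise product, then sum
            res.append(res0)
    res = torch.Tensor(res).view([1, 1, n - 2, n - 2])
    return res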
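- Estimated time complexity (parameterised by the image size n*n)
The program shows that the cost is dominated by the doubly nested loop. Let the image be n*n and the kernel 3*3.
The outer for loop runs $n-2$ times, so it contributes a factor of $n$;
the inner loop also runs $n-2$ times, so the loop body executes $(n-2)^2$ times in total. Each iteration takes a 3*3 slice and sums 9 products, which is a constant amount of work that does not depend on $n$;
ignoring other factors, the complexity for a fixed 3*3 kernel is therefore $O(n^2)$.
- Measured time complexity (the image is passed as the argument)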
best, other = big_o.big_o(multi_test, positive_int_generator, n_repeats=1)
print(best)
Best fit reported for each image size (set by hand in positive_int_generator):
160*160 >>> Cubic: time = 0.43 + -2.4E-17*n^3 (sec)
100*100 >>> Cubic: time = 0.2 + -2.4E-17*n^3 (sec)
64*64 >>> Cubic: time = 0.062 + 1.9E-17*n^3 (sec)
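Across these image sizes big_o reports a cubic best fit, $O(n^3)$. The $n^3$ coefficients are on the order of $10^{-17}$ and change sign, however, so the fitted time is effectively constant within each run, as expected when the generator returns a fixed-size image. The constant terms are more informative: from n = 64 to n = 160 the time grows from 0.062 s to 0.43 s, a factor of about 7, while $n^2$ grows by a factor of 6.25 and $n^3$ by about 15.6. With n_repeats=1 these timings are noisy, but they are closer to the $(n-2)^2$ loop iterations, that is $O(n^2)$, than to $O(n^3)$.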
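The growth of the sliding-window method can also be checked by counting multiply-add operations directly. The function below is a hypothetical helper written for this illustration; it counts only the products inside the loop and ignores slicing and list overhead.

```python
def window_ops(n, k=3):
    """Multiply-add count of the sliding-window method on an n*n image with a k*k kernel."""
    windows = (n - k + 1) ** 2   # number of output elements
    return windows * k * k       # k*k products per window

for n in (64, 128, 256):
    print(n, window_ops(n))
```

For n = 64 this gives 62*62*9 = 34596 operations, and doubling n multiplies the count by roughly 4, which is the signature of quadratic growth for a fixed kernel size.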
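PS: The complexity analysis in this article reflects my personal view as a beginner. If you spot any mistakes, corrections are welcome. Thank you!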