问题背景
在使用torch_dct 时报错,经调研,是torch版本过高,dct中调用的旧的fft相关函数已经更新。
探索过程
参考[2]中说的对应关系如下:
旧版 | 新版 |
---|
torch.rfft(input, signal_ndim=2, normalized=False, onesided=False) | torch.fft.fft() | torch.rfft(input, signal_ndim=2, normalized=False, onesided=True) | torch.fft.rfft() | torch.irfft(input, signal_ndim=2, normalized=False, onesided=False) | torch.fft.ifft2() | torch.irfft(input, signal_ndim=2, normalized=False, onesided=True) | torch.fft.irfft2() |
感觉博客[2]讲的比较乱,而且它只提了对于2D,而我现在需要处理1D的信号,去pytorch官网[3]才看明白函数是咋用的。
torch_dct 中对于2D和3D的DCT都是调用1D的DCT函数实现的,而1D的DCT中调用了1D的FFT,所以我需要针对我自己碰到的问题另寻方法进行改进,
解决方案
这里参考[1]中提到2D fft的解决方案,我给出对于2D torch_dct 的解决方案:在_dct.py 的开头添加下面这句话:
try:
from torch import irfft
from torch import rfft
except ImportError:
def rfft(x, d):
t = torch.fft.fft(x, dim = (-d))
r = torch.stack((t.real, t.imag), -1)
return r
def irfft(x, d):
t = torch.fft.ifft(torch.complex(x[:,:,0], x[:,:,1]), dim = (-d))
return t.real
同时在调用到的地方对函数名从torch.rfft 和torch.irfft 中进行修改:
Vc = rfft(v, 1)
v = irfft(V, 1)
结果验证
和cv2的dct和idct进行对比,计算得到的结果相同: 代码如下:
import os
import cv2
import time
import torch
import torch_dct
import numpy as np
bgr_img = np.random.random((24,24))
bgr_tensor = torch.tensor(bgr_img)
dct_img = cv2.dct(bgr_img.astype(np.float32))
dct_img_tensor = torch.tensor(dct_img)
dct_tensor = torch_dct.dct_2d(bgr_tensor, 'ortho')
print(dct_img_tensor - dct_tensor)
idct_img = cv2.idct(dct_img)
idct_img_tensor = torch.tensor(idct_img)
idct_tensor = torch_dct.idct_2d(dct_tensor, 'ortho')
print(idct_img_tensor - idct_tensor)
参考链接
[1] 旧版pytorch中torch.rfft和irfft在新版本中的对应 [2] 解决报错:AttributeError: module ‘torch‘ has no attribute ‘irfft‘ [3] https://pytorch.org/docs/stable/fft.html
|