(Leading the post with one of my earlier denoising results.)
1. Reflections
For my undergrad thesis I implemented image-filtering algorithms on an FPGA, and my graduate work is deep-learning-based graphics, still with a convolutional backbone. I figured my coding and fundamentals were decent, so I entered the competition fairly confident: read a few top-conference denoising papers, reimplement and borrow from them, and a good result should follow. Reality check: of the 1000+ entrants, more than half either never submitted or only submitted the baseline, and I sit around 100th. Well, it's not over yet; by tomorrow's deadline my rank will probably be closer to 200, and I just can't grind any further. Three problems stood out:
- The gap between me and the strong players is still considerable. Part of it is research direction; denoising isn't my specialty.
- Money is all I need?? The school servers were either fully queued or under repair, and I couldn't get time on my group's machines either (as the one entering a competition on the side, I felt bad elbowing others out). Even tuning a single hyperparameter was hard, and when I did get a slot, the batch_size had to be tiny just to run.
- I'm the lab's designated gofer... and on top of that my paper got rejected outright. The reviewer comments were something else.
I know I'm not at the top level, but I still wanted to give it a shot. Rant over; on to the substance:
2. Gains
Despite all the difficulties, I gained a lot:
- Read several top-conference CV denoising papers and got a first taste of the low-level vision direction.
- Tried to reproduce two of them. Neither beat the baseline, though they came close (my reproductions may be off, since I only borrowed the ideas rather than copying the code verbatim). In the end I heavily modified yet another paper's design.
- For the first time I wrote a complete deep-learning project from scratch, from the dataloader, network architecture, and weight initialization through the training strategy and loss function (previously I only tweaked other people's codebases). I hit plenty of pitfalls and learned a lot of new things.
3. Experience Sharing (Selected Code and Comments)
3.1 Input
The images have to be cropped into patches: a full image is too large, and with even a moderately sized network, a 32 GB GPU runs out of memory. Split each image into multiple tiles; pseudocode:
```python
# crop the batch to one patch (a:b, c:d are the patch bounds) before feeding the model
tmp['imgs'] = data['imgs'][:, :, a:b, c:d]
tmp['gts'] = data['gts'][:, :, a:b, c:d]
model.set_input(tmp)
```
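A fuller sketch of the tiling loop the pseudocode above implies. `tile_patches` and the patch size of 256 are my own illustrative choices, not the competition code:

```python
import torch

def tile_patches(img, patch=256):
    """Split an (N, C, H, W) tensor into a list of patch-sized crops.
    Assumes H and W are divisible by `patch` for simplicity."""
    n, c, h, w = img.shape
    patches = []
    for top in range(0, h, patch):
        for left in range(0, w, patch):
            patches.append(img[:, :, top:top + patch, left:left + patch])
    return patches

x = torch.randn(1, 4, 512, 512)
tiles = tile_patches(x, patch=256)   # 2 tiles per side -> 4 tiles total
```

Each tile (together with the matching ground-truth crop) can then be fed to the model exactly as in the snippet above.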
3.2 Network
The main ideas my network borrows:
1. Learn the noise rather than regressing the clean pixels end to end (the network seems easier to fit this way?).
2. Use depthwise (channel-separable) convolutions and moderately increase the channel count (GPU memory is tight, so training is slow).
3. Try larger convolution kernels (again, slow with limited memory).
(The competition caps model size.) Both wider channels and larger kernels increase memory use; given my hardware, I only increased the channel count. Implementation details below.
A plain U-Net, modified from the baseline:
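To see why depthwise (channel-separable) convolution frees up budget for wider layers, compare parameter counts. The 128-channel 3×3 layer below is a hypothetical example, not a layer from my network:

```python
import torch.nn as nn

# standard dense convolution: every output channel sees every input channel
dense = nn.Conv2d(128, 128, kernel_size=3, padding=1)
# depthwise convolution: one 3x3 filter per channel (groups == channels)
depthwise = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=128)

n_dense = sum(p.numel() for p in dense.parameters())      # 128*128*9 + 128 = 147584
n_depth = sum(p.numel() for p in depthwise.parameters())  # 128*9 + 128 = 1280
```

Roughly a 100x saving per layer, which is what makes the wider `[128, 256, 512, 1024]` channel plan below affordable.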
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Unet2(nn.Module):
    def __init__(self, dim=4):
        super(Unet2, self).__init__()
        self.dims = [32, 64, 128, 256, 512]
        self.ks = [3, 3, 3, 3, 3]
        self.dims_up = self.dims[::-1]
        self.ks_up = self.ks[-2::-1]
        # encoder: conv blocks with 2x max-pool downsampling
        self.first_block = Block2(dim, self.dims[0], self.ks[0])
        self.first_pool = nn.MaxPool2d(kernel_size=2)
        for i, dim_in in enumerate(self.dims[:-2]):
            dim_out = self.dims[i + 1]
            setattr(self, 'Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks[i + 1]))
            setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))
        self.conv_mid = Block2(self.dims[-2], self.dims[-1], self.ks[-1])
        # decoder: transposed-conv upsampling, then concat with the skip connection
        for i, dim_in in enumerate(self.dims_up[:-1]):
            dim_out = self.dims_up[i + 1]
            setattr(self, 'ConvTrans{}'.format(i),
                    nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2, bias=True))
            setattr(self, 'up_Block{}'.format(i), Block2(dim_in, dim_out, k=self.ks_up[i]))
        self.last_conv = nn.Conv2d(self.dims[0], dim, 1, bias=True)

    def forward(self, x):
        n, c, h, w = x.shape
        # pad H/W up to multiples of 32 so repeated 2x pooling divides evenly
        h_pad = 32 - h % 32 if not h % 32 == 0 else 0
        w_pad = 32 - w % 32 if not w % 32 == 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
        list_pools = []
        x_bk = x
        x = self.first_block(padded_image)
        list_pools.append(x)
        x = self.first_pool(x)
        for i, dim_in in enumerate(self.dims[:-2]):
            x = getattr(self, 'Block{}'.format(i))(x)
            list_pools.append(x)
            x = getattr(self, 'pool{}'.format(i))(x)
        x = self.conv_mid(x)
        for i, dim_in in enumerate(self.dims_up[:-1]):
            x = getattr(self, 'ConvTrans{}'.format(i))(x)
            x = torch.cat([x, list_pools.pop()], 1)  # skip connection
            x = getattr(self, 'up_Block{}'.format(i))(x)
        x = self.last_conv(x)
        # the network predicts a residual; crop off the padding and add the input back
        out = x[:, :, :h, :w] + x_bk
        return out


class Block2(nn.Module):
    """Two conv layers, each followed by LeakyReLU."""
    def __init__(self, dim_in, dim_out, k=3):
        super(Block2, self).__init__()
        self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=k, padding=k // 2,
                               padding_mode='zeros', bias=True)
        self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=k, padding=k // 2,
                               padding_mode='zeros', bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.leaky_relu(x)
        x = self.conv2(x)
        x = self.leaky_relu(x)
        return x

    def leaky_relu(self, x, a=0.2):
        # max(a*x, x) is LeakyReLU with negative slope a
        out = torch.max(a * x, x)
        return out
```
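The `forward` above pads H and W up to the next multiple of 32 so the repeated 2x downsampling divides evenly, then crops the output back. A standalone sketch of that pad-then-crop trick (the 250x300 input is just an example):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 250, 300)     # H, W not multiples of 32
n, c, h, w = x.shape
h_pad = (32 - h % 32) % 32          # 6  -> padded H = 256
w_pad = (32 - w % 32) % 32          # 20 -> padded W = 320
# pad only on the right/bottom, replicating edge pixels
padded = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
# ... run the network on `padded` ...
cropped = padded[:, :, :h, :w]      # crop back to the original size
```

Because the padding is appended only on the right and bottom, cropping to `[:h, :w]` recovers exactly the original region.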
The network I actually used, a heavily modified ConvNeXt:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Our(nn.Module):
    def __init__(self, dim=4):
        super(Our, self).__init__()
        self.dims = [128, 256, 512, 1024]
        self.ks = [3, 3, 3, 3]
        self.dims_up = self.dims[::-1]
        self.ks_up = self.ks[-2::-1]
        # encoder: ConvNeXt-style blocks with 2x max-pool downsampling
        self.first_block = Block(dim, self.dims[0], self.ks[0])
        self.first_pool = nn.MaxPool2d(kernel_size=2)
        for i, dim_in in enumerate(self.dims[:-2]):
            dim_out = self.dims[i + 1]
            setattr(self, 'Block{}'.format(i), Block(dim_in, dim_out, k=self.ks[i + 1]))
            setattr(self, 'pool{}'.format(i), nn.MaxPool2d(kernel_size=2))
        self.conv_mid = Block(self.dims[-2], self.dims[-1], self.ks[-1])
        # decoder: transposed-conv upsampling plus skip connections
        for i, dim_in in enumerate(self.dims_up[:-1]):
            dim_out = self.dims_up[i + 1]
            setattr(self, 'ConvTrans{}'.format(i),
                    nn.ConvTranspose2d(dim_in, dim_out, 2, stride=2))
            setattr(self, 'up_Block{}'.format(i), Block(dim_in, dim_out, k=self.ks_up[i]))
        # head: LayerNorm + Linear act as a normalized 1x1 conv in NHWC layout
        self.last_ln = nn.LayerNorm(self.dims[0], eps=1e-6)
        self.last_conv = nn.Linear(self.dims[0], dim)

    def forward(self, x):
        n, c, h, w = x.shape
        # pad H/W up to multiples of 32 so repeated 2x pooling divides evenly
        h_pad = 32 - h % 32 if not h % 32 == 0 else 0
        w_pad = 32 - w % 32 if not w % 32 == 0 else 0
        padded_image = F.pad(x, (0, w_pad, 0, h_pad), 'replicate')
        list_pools = []
        x_bk = x
        x = self.first_block(padded_image)
        list_pools.append(x)
        x = self.first_pool(x)
        for i, dim_in in enumerate(self.dims[:-2]):
            x = getattr(self, 'Block{}'.format(i))(x)
            list_pools.append(x)
            x = getattr(self, 'pool{}'.format(i))(x)
        x = self.conv_mid(x)
        for i, dim_in in enumerate(self.dims_up[:-1]):
            x = getattr(self, 'ConvTrans{}'.format(i))(x)
            x = torch.cat([x, list_pools.pop()], 1)  # skip connection
            x = getattr(self, 'up_Block{}'.format(i))(x)
        x = x.permute(0, 2, 3, 1).contiguous()   # NCHW -> NHWC for LayerNorm/Linear
        x = self.last_ln(x)
        x = self.last_conv(x)
        x = x.permute(0, 3, 1, 2).contiguous()   # back to NCHW
        # the network predicts a residual; crop off the padding and add the input back
        out = x[:, :, :h, :w] + x_bk
        return out


class Block(nn.Module):
    """ConvNeXt-style block: depthwise conv -> LayerNorm -> pointwise MLP with GELU."""
    def __init__(self, dim_in, dim_out, k=9):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(dim_in, dim_in, groups=dim_in,
                              kernel_size=k, padding=k // 2)   # depthwise conv
        self.ln = nn.LayerNorm(dim_in, eps=1e-6)
        self.conv1x1up = nn.Linear(dim_in, dim_in * 2)   # pointwise expansion (1x1 conv in NHWC)
        self.act = nn.GELU()
        self.conv1x1dn = nn.Linear(dim_in * 2, dim_out)  # pointwise projection
        self.w = nn.Parameter(torch.zeros(1))            # layer-scale gate, initialized to zero
        self.res_conv = nn.Conv2d(dim_in, dim_out, 1)    # 1x1 conv to match channels on the residual path

    def forward(self, x):
        identity = x
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1).contiguous()   # NCHW -> NHWC
        x = self.ln(x)
        x = self.conv1x1up(x)
        x = self.act(x)
        x = self.conv1x1dn(x)
        x = x.permute(0, 3, 1, 2).contiguous()   # back to NCHW
        x = x * self.w                           # scale the main branch
        x = x + self.res_conv(identity)          # residual connection
        return x
```
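The `self.w = nn.Parameter(torch.zeros(1))` in `Block` is a layer-scale-style gate: at initialization the main branch contributes nothing, so the block starts out as just its 1x1 residual path and the main branch is learned gradually. A minimal illustration with stand-in tensors:

```python
import torch
import torch.nn as nn

w = nn.Parameter(torch.zeros(1))   # learnable scale, starts at 0
main = torch.randn(2, 8)           # stand-in for the depthwise/MLP branch output
res = torch.randn(2, 8)            # stand-in for res_conv(identity)
out = main * w + res               # at init, out equals the residual path exactly
```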
3.3 Loss Function
```python
loss = torch.nn.L1Loss()
```
In my tests, plain L1 works best. The fancier options (L2, SSIM, and the like) underperformed, though this is alchemy, so maybe they just don't suit my network.
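L1 is just the mean absolute error over all pixels; a tiny worked example with made-up numbers:

```python
import torch

loss_fn = torch.nn.L1Loss()
pred = torch.tensor([1.0, 2.0, 3.0])
gt = torch.tensor([0.0, 0.0, 0.0])
l1 = loss_fn(pred, gt)   # (|1| + |2| + |3|) / 3 = 2.0
```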
3.4 Traditional Filtering
Ha, I also tried classical denoising, and along the way wrote a bilateral filter in pure Python (ported from my old MATLAB code). I have to say: deep learning wins, hands down.
```python
import numpy as np


def bilateral_filter(img):
    r = 20               # window radius
    sigma_space = 15.0   # spatial Gaussian sigma
    sigma_color = 10.0   # range (intensity) Gaussian sigma
    # precompute the spatial weight table, symmetric about the window center
    w_space = np.zeros((2 * r + 1, 2 * r + 1))
    for i in range(-r, r + 1):
        for j in range(-r, r + 1):
            tmp = i * i + j * j
            w_space[i + r, j + r] = np.exp(-float(tmp) / (2 * sigma_space * sigma_space))
    # precompute the range-weight lookup table over all 256 intensity differences
    w_color = np.zeros((1, 256))
    for i in range(256):
        w_color[0, i] = np.exp(-float(i * i) / (2 * sigma_color * sigma_color))
    height, width, channel = img.shape
    dst_img = img.copy().astype(np.float64)  # accumulate in float to avoid uint8 truncation
    for h in range(r, height - r):
        for w in range(r, width - r):        # note: width here, not height
            for c in range(channel):
                p_c = img[h, w, c]
                p_win = img[h - r:h + r + 1, w - r:w + r + 1, c]
                # cast before subtracting so uint8 inputs don't wrap around
                c_w = np.abs(p_win.astype(int) - int(p_c))
                c_w = w_color[0, c_w]                 # range weights via the LUT
                w_tmp = w_space * c_w                 # combined bilateral weights
                p_sum = p_win * w_tmp
                p_sum = np.sum(p_sum) / np.sum(w_tmp) # normalized weighted average
                dst_img[h, w, c] = p_sum
    return dst_img
```
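One easy way for the spatial-weight loop to go wrong is an off-by-one in the offsets, which silently makes the kernel asymmetric. With the indexing used above, the table peaks at exactly 1.0 in the center and is symmetric under flips. A quick sanity check (a small `r` is used here just for illustration):

```python
import numpy as np

r, sigma_space = 2, 15.0
w_space = np.zeros((2 * r + 1, 2 * r + 1))
for i in range(-r, r + 1):
    for j in range(-r, r + 1):
        w_space[i + r, j + r] = np.exp(-(i * i + j * j) / (2 * sigma_space ** 2))
# center weight is exp(0) = 1.0; the table is symmetric about the center
```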
4. Main References
- https://zhuanlan.zhihu.com/p/455913104 (ConvNeXt: A ConvNet for the 2020s)
- https://zhuanlan.zhihu.com/p/349644858 (how to get GPU time for free)
- https://blog.csdn.net/u011447962/article/details/123510680 (CVPR 2022 | RepLKNet)
- https://github.com/gbstack/CVPR-2022-papers#SG (CVPR2022 Papers (Papers/Codes/Demos))