# Python implementation of global top-k pooling:
import torch
from torch import nn
class TopKPool(nn.Module):
    """Global top-k pooling: average the k largest activations per channel.

    Flattens the spatial dimensions of a (B, C, H, W) input, selects the
    k largest values along the flattened axis for each (batch, channel)
    pair, and returns their mean — yielding a (B, C) tensor.
    """

    def __init__(self, k):
        """
        Args:
            k (int): number of largest activations to average; must satisfy
                1 <= k <= H * W of the inputs later passed to forward().
        """
        super().__init__()
        self.k = k

    def forward(self, x):
        """Pool x of shape (B, C, H, W) down to shape (B, C).

        Args:
            x (torch.Tensor): 4-D input tensor (B, C, H, W).

        Returns:
            torch.Tensor: (B, C) tensor of per-channel top-k means.
        """
        b, c, _, _ = x.shape
        flat = x.view(b, c, -1)                    # (B, C, H*W)
        topk_vals, _ = flat.topk(self.k, dim=-1)   # k largest per channel
        return topk_vals.mean(dim=-1)              # (B, C)
if __name__ == "__main__":
aa = torch.randn(2, 3, 5, 5)
print("aa: ", aa.size())
print("----------------------------")
print("aa: ", aa)
bb = nn.MaxPool2d(kernel_size=5, stride=2)
c = bb(aa)
print("c: ", c)
print("c.shape", c.shape)
print("----------------------------")
topkmaxpool = TopKPool(k=2)
dd = topkmaxpool(aa)
print("dd: ", dd)
print("dd.shape", dd.shape)
|