# 1: thop -- profile one forward pass of the model with a dummy input.
import torch
from thop import profile

# Build the dummy input on the same device as the model's weights so the
# forward pass does not fail with a CPU/GPU device mismatch. Falls back to
# the CPU for a model that has no parameters.
_first_param = next(self.modelG.parameters(), None)
_device = _first_param.device if _first_param is not None else torch.device("cpu")
x = torch.randn(1, 3, 256, 256, device=_device)  # dummy input of shape (1, 3, 256, 256)
# thop reports multiply-accumulate operations (MACs), not FLOPs;
# FLOPs are roughly 2 * MACs.
macs, params = profile(self.modelG, inputs=(x,))
print(f'MACs is {macs / 1e6:.2f}M')  # compute cost
print(f'params is {params / 1e6:.2f}M')  # parameter count
# 2: count only trainable parameters (tensors with requires_grad=True).
num = sum(
    param.numel() for param in self.modelG.parameters() if param.requires_grad
)
print('param is %.2fM' % (num / 1e6))
# 3: count all parameters, trainable or frozen.
# A generator avoids building a throwaway list; numel() matches block 2.
total = sum(param.numel() for param in self.modelG.parameters())
print("Number of parameter: %.2fM" % (total / 1e6))
# 4: ptflops -- pass the input shape without the batch dimension;
# ptflops profiles with batch_size=1 by default.
import sys

from ptflops import get_model_complexity_info

# Like thop, ptflops counts multiply-accumulate operations (MACs). With
# as_strings=True both results come back as human-readable strings.
macs, params = get_model_complexity_info(
    self.modelG, (3, 256, 256), as_strings=True, print_per_layer_stat=False
)
# f-string interpolation (rather than `+`) still prints something instead
# of raising TypeError if a non-string value comes back.
print(f'MACs: {macs}')
print(f'Params: {params}')
# Stop the script once profiling is done. The bare `exit` builtin is a
# `site`-module convenience that is not guaranteed to exist; sys.exit is
# the portable form.
# NOTE(review): status -1 (255 on POSIX) reports failure to the shell;
# confirm that is intended rather than 0.
sys.exit(-1)