1、安装 thop
注意:输入必须是四维张量,形状为 (batch, channels, height, width)
pip install thop
------------------------------------------------------------------------------
# Count the FLOPs and parameters of the project's YoloBody network with thop.
import torch  # required by torch.randn below; the original snippet never imported it
from nets.yolo4 import YoloBody
from thop import profile

model = YoloBody(3, 20)
# thop profiles the model on a 4-D dummy batch. The tensor is named
# dummy_input so the builtin input() is not shadowed.
# NOTE(review): 420 is not a multiple of 32 — confirm YoloBody accepts this
# spatial size (416 is the more usual choice for YOLO-style networks).
dummy_input = torch.randn(1, 3, 420, 420)
flops, params = profile(model, inputs=(dummy_input,))
print('flops', flops)
print('params', params)
或者(加载已保存的完整模型后再统计):
# Profile a model restored from a checkpoint file and print a one-line
# summary: name | parameters (millions) | FLOPs (billions).
import torch  # required by torch.load / torch.randn; missing in the original
from torchvision.models import resnet50
from thop import profile

checkpoints = '模型path'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
# NOTE(review): assumes the checkpoint stores a whole nn.Module rather than a
# state_dict; profile() needs a callable module — confirm.
model = torch.load(checkpoints)
model_name = 'yolov3 cut asff'
# Named dummy_input so the builtin input() is not shadowed.
dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(dummy_input,), verbose=True)
# Print the model's name; the original passed the module object itself to
# %s, which dumped its whole repr and left model_name unused.
print("%s | %.2f | %.2f" % (model_name, params / (1000 ** 2), flops / (1000 ** 3)))
2、torchstat
使用 torchstat 这个库来查看网络模型的一些信息,包括总的参数量 params、MAdd、显卡内存占用量和 FLOPs 等。注意:stat 的第二个参数是不含 batch 维的输入尺寸 (channels, height, width)。
# Print torchstat's statistics report for a torchvision ResNet-50.
from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d

model = resnet50()
# stat() is given the input size as a (3, 224, 224) tuple, so no dummy tensor
# is needed. The original built an unused one with an un-imported torch,
# which raised NameError before stat() could run.
stat(model, (3, 224, 224))
参考
|