[Artificial Intelligence] CV + Deep Learning: PyTorch Reproduction Series of Network Architectures: basenets (Backbones) (Part 1)

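Introduction

This series focuses on reproducing the classic deep learning network models used in computer vision (classification, object detection, and semantic segmentation), explained step by step so that beginners can use them!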

All of the code runs without errors!

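The series starts by reproducing the classic deep learning network models (basenets). These are the well-known Backbones, but because I make some changes to them, this series calls them basenets rather than Backbones. Most of them are classic classification networks (1. to 8.), and the rest (9., 10.) are object detection Backbones:

1.LeNet5(√)

2.AlexNet(√)

3.VGG(√)

4.ResNet(√)

5.GoogLeNet

6.MobileNet

7.ShuffleNet

8.EfficientNet

9.VovNet

10.DarkNet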

...

Notes:

a) The complete code is uploaded to my GitHub:

https://github.com/HanXiaoyiGitHub/Simple-CV-Pytorch-master

b) Build environment (you don't strictly need this exact environment if you can debug version issues yourself!):

python == 3.9.12
torch == 1.11.0+cu113
torchvision == 0.12.0+cu113
torchaudio == 0.11.0+cu113
pycocotools == 2.0.4
numpy
Cython
matplotlib
opencv-python
tqdm
thop

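c) The classification datasets are ImageNet or CIFAR10, with the following directory layout (coco and voc are for object detection and semantic segmentation and are not needed yet):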

dataset path: /data/

data
|
|----coco----|----coco2017
|
|----cifar
|
|----ImageNet----|----ILSVRC2012
|
|----VOCdevkit


coco2017 path: /data/coco/coco2017
coco2017
|
|
|----annotations
|----train2017
|----test2017
|----val2017


voc path: /data/VOCdevkit
|
|               |----Annotations
|               |----ImageSets
|----VOC2007----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject
|
|
|               |----Annotations
|               |----ImageSets
|----VOC2012----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject


ILSVRC2012 path : /data/ImageNet/ILSVRC2012
|
|----train
|
|----val


cifar path: /data/cifar
|
|----cifar-10-batches-py
|
|----cifar-10-python.tar.gz

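d) AMP (automatic mixed precision) is used to speed up training on the GPU. If you are not sure how to use it, see the following article:

How to speed up network training in PyTorch? (autocast and GradScaler) https://blog.csdn.net/XiaoyYidiaodiao/article/details/124854343?spm=1001.2014.3001.5502

Therefore, @autocast() must be added in front of the forward function of each network model. In addition, because a torch version above 1.4 is used, every layer that has an inplace argument, such as ReLU(inplace=False) and Dropout(inplace=False), must have inplace set to False.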
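A minimal sketch of one mixed-precision training step with `autocast` and `GradScaler` follows. It assumes a toy linear model and random data, both hypothetical and not from the project. The models in this article put `@autocast()` on `forward`. Here the context-manager form is used in the training loop instead, which has the same effect on the forward pass. AMP is enabled only when CUDA is available, so the same code also runs on a CPU as plain FP32.

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")

# Hypothetical toy model and batch, only to illustrate the call order.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=use_amp)  # becomes a pass-through when disabled

images = torch.randn(8, 16, device=device)
labels = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
with autocast(enabled=use_amp):
    # forward pass and loss run in mixed precision on the GPU
    outputs = model(images)
    loss = criterion(outputs, labels)
scaler.scale(loss).backward()  # scale the loss so FP16 gradients do not underflow
scaler.step(optimizer)         # unscale the gradients, then call optimizer.step()
scaler.update()                # adjust the scale factor for the next iteration
print(outputs.shape)
```

The backward pass is kept outside the `autocast` block, which is the usual pattern: autocast only needs to cover the forward pass and the loss.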

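e) LeNet5, VGG, and AlexNet use fully connected layers whose input length depends on the image size. For these architectures, images must therefore be resized to a fixed size during preprocessing.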
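The sketch below shows why the size must be fixed. It applies the standard conv/pool output-size formula, floor((n + 2p - k) / s) + 1, to the layer settings used later in this article. The helper names are hypothetical. The point is that the flattened length fed to the first `nn.Linear` changes with the input size.

```python
def out_size(size, kernel, stride=1, padding=0):
    # output side length of a conv/pool layer: floor((n + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_flat(size):
    size = out_size(size, 5)      # conv 5x5
    size = out_size(size, 2, 2)   # max pool 2x2, stride 2
    size = out_size(size, 5)      # conv 5x5
    size = out_size(size, 2, 2)   # max pool 2x2, stride 2
    return 16 * size * size       # features fed to the first nn.Linear


def alexnet_side(size):
    size = out_size(size, 11, 4, 2)  # conv 11x11, stride 4, padding 2
    size = out_size(size, 3, 2)      # max pool 3x3, stride 2
    size = out_size(size, 5, 1, 2)   # conv 5x5, padding 2
    size = out_size(size, 3, 2)      # max pool 3x3, stride 2
    for _ in range(3):               # three conv 3x3, padding 1 (size unchanged)
        size = out_size(size, 3, 1, 1)
    return out_size(size, 3, 2)      # final max pool


def vgg_side(size):
    for _ in range(5):  # five 2x2 max pools; the padded 3x3 convs keep the size
        size = out_size(size, 2, 2)
    return size


print(lenet5_flat(32))    # 400 = 16 * 5 * 5, matches nn.Linear(16 * 5 * 5, 120)
print(lenet5_flat(28))    # 256, no longer matches, so the forward pass fails
print(alexnet_side(224))  # 6, matches nn.Linear(6 * 6 * 128 * 2, ...)
print(vgg_side(224))      # 7, matches nn.Linear(7 * 7 * 512, 4096)
```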

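f) Project file structure

The OS used is Ubuntu 20.04, although the project also runs on Windows (I have tested it). Some folders are not used yet. Ignore them for now; I will cover them later.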

project path: /data/PycharmProject/

Simple-CV-Pytorch-master path: /data/PycharmProject/Simple-CV-Pytorch-master
|
|----checkpoints ( resnet50-19c8e357.pth \COCO_ResNet50.pth[RetinaNet]\ VOC_ResNet50.pth[RetinaNet] )
|
|            |----cifar.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----CIAR_labels.txt
|            |----coco.py
|            |----coco_eval.py
|            |----coco_labels.txt
|----data----|----__init__.py
|            |----config.py ( path )
|            |----imagenet.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----ImageNet_labels.txt
|            |----voc0712.py
|            |----voc_eval.py
|            |----voc_labels.txt
|                                     |----crash_helmet.jpg
|----images----|----classification----|----sunflower.jpg
|              |                      |----photocopier.jpg
|              |                      |----automobile.jpg
|              |
|              |----detection----|----000001.jpg
|                                |----000001.xml
|                                |----000002.jpg
|                                |----000002.xml
|                                |----000003.jpg
|                                |----000003.xml
|
|----log(XXX[ detection or classification ]_XXX[  train or test or eval ].info.log)
|
|              |----__init__.py
|              |
|              |              |----__init.py
|              |----anchor----|----RetinaNetAnchors.py
|              |
|              |               |----lenet5.py
|              |               |----alexnet.py
|              |----basenet----|----vgg.py
|              |               |----resnet.py
|              |
|              |                 |----DarkNetBackbone.py
|              |----backbones----|----__init__.py ( not finished yet )
|              |                 |----ResNetBackbone.py
|              |                 |----VovNetBackbone.py
|              |
|              |
|              |
|----models----|----heads----|----__init.py
|              |             |----RetinaNetHeads.py
|              |
|              |              |----RetinaNetLoss.py
|              |----losses----|----__init.py
|              |
|              |             |----FPN.py
|              |----necks----|----__init__.py
|              |             |-----FPN.txt
|              |
|              |----RetinaNet.py
|
|----results ( eg: detection ( VOC or COCO AP ) )
|
|----tensorboard ( Loss visualization )
|
|----tools                       |----eval.py
|         |----classification----|----train.py
|         |                      |----test.py
|         |
|         |
|         |
|         |                 |----eval_coco.py
|         |                 |----eval_voc.py
|         |----detection----|----test.py
|                           |----train.py
|
|
|             |----AverageMeter.py
|             |----BBoxTransform.py
|             |----ClipBoxes.py
|             |----Sampler.py
|             |----iou.py
|----utils----|----__init__.py
|             |----accuracy.py
|             |----augmentations.py
|             |----collate.py
|             |----get_logger.py
|             |----nms.py
|             |----path.py
|
|----FolderOrganization.txt
|
|----main.py
|
|----README.md
|
|----requirements.txt

1.LeNet5(size: 32 * 32 * 3)

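Figure 1.

Following Figure 1, the code is reproduced below.

I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction you can ignore nn.BatchNorm2d() and delete it from the code.

Adjust the output size of the last fully connected layer to the number of classes in your dataset.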

from torch import nn
from torch.cuda.amp import autocast


class lenet5(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, num_classes=1000, init_weights=False):
        super(lenet5, self).__init__()
        self.num_classes = num_classes
        self.layers = nn.Sequential(
            # input:32 * 32 * 3 -> 28 * 28 * 6
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            # 28 * 28 * 6 -> 14 * 14 * 6
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # 14 * 14 * 6 -> 10 * 10 * 16
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # 10 * 10 * 16 -> 5 * 5 * 16
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Flatten(),
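            nn.Linear(16 * 5 * 5, 120),
            # without an activation, two stacked Linear layers collapse into one linear map
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU())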
        self.classifier = nn.Linear(84, self.num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
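The faithful variant mentioned above drops nn.BatchNorm2d(). A minimal sketch of it could look like the class below; the name `LeNet5Plain` is hypothetical and not part of the project. Without BatchNorm the convolutions get their bias back (the nn.Conv2d default). The sketch keeps ReLU as in the code above rather than the tanh-style activations of the original LeNet-5. `@autocast()` is omitted so it runs the same on a CPU.

```python
import torch
from torch import nn


class LeNet5Plain(nn.Module):
    """Hypothetical BN-free LeNet5 sketch for 32x32x3 inputs."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),   # 32x32x3 -> 28x28x6 (bias=True by default)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # -> 14x14x6
            nn.Conv2d(6, 16, kernel_size=5),  # -> 10x10x16
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # -> 5x5x16
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = LeNet5Plain(num_classes=10)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
has_bn = any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
print(has_bn)     # False
```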

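2.AlexNet (Size: 224 * 224 * 3)

Figure 2.

The architecture is shown in Figure 2. If it is not clear enough, refer to Figure 3 below.

Figure 3.

Figure 3 is converted into Figure 4. The original AlexNet was split across two GPUs because the compute power of the time was insufficient. Today's hardware can handle the whole network, so it runs on a single GPU.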
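The sketch below shows what the two-GPU split meant. In the original AlexNet, conv2 (like conv4 and conv5) connected each GPU's filters only to the half of the input channels on that GPU. The `groups=2` argument of `nn.Conv2d` reproduces that on one device. The single-GPU layer used in this article connects every filter to all 96 input channels, so it has twice as many weights. Treating `groups=2` as the equivalent of the split is an illustration here, not the author's code.

```python
import torch
from torch import nn

# Two-GPU style: each group of 128 filters sees only 48 of the 96 input channels.
split = nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2, bias=False)
# Single-GPU style used in this article: every filter sees all 96 channels.
merged = nn.Conv2d(96, 256, kernel_size=5, padding=2, bias=False)

print(split.weight.shape)   # torch.Size([256, 48, 5, 5])
print(merged.weight.shape)  # torch.Size([256, 96, 5, 5])
print(split.weight.numel(), merged.weight.numel())  # 307200 614400

# Both layers produce the same output shape on the 27x27x96 feature map.
y = split(torch.randn(1, 96, 27, 27))
print(y.shape)              # torch.Size([1, 256, 27, 27])
```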

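Figure 4.

Adjust the output size of the last fully connected layer to the number of classes in your dataset.

I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction you can ignore nn.BatchNorm2d() and delete it from the code.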

import torch.nn as nn
from torch.cuda.amp import autocast


class alexnet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(alexnet, self).__init__()
        self.layers = nn.Sequential(
            # input: 224 * 224 * 3 -> 55 * 55 * (48*2)
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            # 55 * 55 * (48*2) -> 27 * 27 * (48*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 27 * 27 * (48*2) -> 27 * 27 * (128*2)
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 27 * 27 * (128*2) -> 13 * 13 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 13 * 13 * (128*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (128*2)
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 13 * 13 * (128*2) -> 6 * 6 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(6 * 6 * 128 * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.ReLU()
        )
        self.classifier = nn.Linear(2048, num_classes)
        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

3.VGG (Size: 224 * 224 * 3)

Figure 5.

Following Figure 5, the architectures marked by the green box are reproduced as follows.
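The number in each VGG name counts its weight layers. Every integer in a `cfgs` entry is one 3x3 convolution, and the head adds three fully connected layers (`fc` holds two `nn.Linear`, `classifier` holds one). The quick check below uses the `cfgs` dictionary copied from the code in this section.

```python
# cfgs copied from the VGG code in this section
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

counts = {}
for name, cfg in cfgs.items():
    convs = sum(1 for v in cfg if v != 'M')  # each integer is one 3x3 conv (+ BN + ReLU)
    counts[name] = convs
    print(name, convs, convs + 3)            # + 3 fully connected layers
# vgg11 8 11, vgg13 10 13, vgg16 13 16, vgg19 16 19
```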

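I could not put up with the poor accuracy, so I added nn.BatchNorm2d(i) and also applied transfer learning.

Adjust the output size of the last fully connected layer to the number of classes in your dataset.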

import torch
from torch import nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'vgg11',
    'vgg13',
    'vgg16',
    'vgg19',
]
# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg11': '{}/vgg11-bbd30ac9.pth'.format(CheckPoints),
    # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg13': '{}/vgg13-c768596a.pth'.format(CheckPoints),
    # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg16': '{}/vgg16-397923af.pth'.format(CheckPoints),
    # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg19': '{}/vgg19-dcbb9e9d.pth'.format(CheckPoints)

}


def vgg_(arch, num_classes, pretrained, init_weights=False, **kwargs):
    cfg = cfgs["vgg" + arch]
    features = make_features(cfg)
    model = vgg(num_classes=num_classes, features=features, init_weights=init_weights, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
        pretrained_models = torch.load(model_urls["vgg" + arch])
        # transfer learning: the nn.BatchNorm2d layers shift the features.N indices, so these
        # pretrained keys would land on layers of a different shape and must be dropped
        if arch == '11':
            del pretrained_models['features.8.weight']
            del pretrained_models['features.11.weight']
            del pretrained_models['features.16.weight']
        elif arch == '13':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.15.weight']
            del pretrained_models['features.15.bias']  # [512] vs BatchNorm2d(256).bias
            del pretrained_models['features.17.weight']
            del pretrained_models['features.22.weight']
        elif arch == '16':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.17.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.21.bias']  # [512] vs BatchNorm2d(256).bias
            del pretrained_models['features.24.weight']
            del pretrained_models['features.28.weight']
        elif arch == '19':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.21.bias']  # [512] vs BatchNorm2d(256).bias
            del pretrained_models['features.23.weight']
            del pretrained_models['features.28.weight']
            del pretrained_models['features.34.weight']
        else:
            raise ValueError("Pretrained: unsupported VGG depth")
        model.load_state_dict(pretrained_models, strict=False)
    return model


def vgg11(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('11', num_classes, pretrained, init_weights, **kwargs)


def vgg13(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('13', num_classes, pretrained, init_weights, **kwargs)


def vgg16(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('16', num_classes, pretrained, init_weights, **kwargs)


def vgg19(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('19', num_classes, pretrained, init_weights, **kwargs)


class vgg(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(vgg, self).__init__()
        self.num_classes = num_classes
        self.features = features
        self.fc = nn.Sequential(

            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        self.classifier = nn.Linear(4096, self.num_classes)
        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # BN weights are 1-D, so xavier_uniform_ cannot be applied to them
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_features(cfgs: list):
    layers = []
    in_channels = 3
    for i in cfgs:
        if i == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, i, kernel_size=3, stride=1, padding=1, bias=False)
            layers += [conv2d, nn.BatchNorm2d(i), nn.ReLU()]
            in_channels = i
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],

}
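The hard-coded `del` lists in `vgg_` exist because the BN layers shift the `features.N` indices. As a result, some pretrained keys collide with layers of a different shape, and `load_state_dict` raises a size-mismatch error even with `strict=False`. A more general alternative keeps only the tensors whose key and shape both match; it is sketched below as a hypothetical helper `load_matching`, which is not part of the project. A shape match alone does not prove that two tensors belong to corresponding layers; it only avoids the mismatch error.

```python
import torch
from torch import nn


def load_matching(model: nn.Module, state_dict: dict) -> list:
    """Hypothetical helper: copy only tensors whose key AND shape match the model."""
    own = model.state_dict()
    kept = {k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape}
    model.load_state_dict(kept, strict=False)
    return sorted(kept)


# Toy stand-ins: a plain conv stack (like torchvision's VGG without BN)
# and a conv + BN stack (like the vgg in this article).
plain = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 16, 3, padding=1))
with_bn = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU(),
                        nn.Conv2d(8, 16, 3, padding=1, bias=False))

kept = load_matching(with_bn, plain.state_dict())
print(kept)  # ['0.weight']: index 2 is a ReLU in with_bn, so '2.weight' is skipped
```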

4.ResNet

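Figure 6.

As shown in Figure 6, the networks (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152) are reproduced as follows.

First, how is each block reproduced? The 18-layer and 34-layer blocks are shown in the green box of Figure 7 below. The 50-layer, 101-layer, and 152-layer blocks are shown in the blue box.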

block: 18-layer, 34-layer

# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

block: 50-layer, 101-layer, 152-layer

# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)

    to

    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                               stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # padding=1 keeps the spatial size aligned with the (downsampled) identity branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1,
                               stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        identity = x

        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

To rebuild the whole ResNet model, first reproduce the first convolution layer and the max pooling layer:

class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64


        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

For the subsequent layers, the code corresponds to the modules in Figure 8 as follows:

conv2_x -> self.layer1, 
conv3_x -> self.layer2, 
conv4_x -> self.layer3, 
conv5_x -> self.layer4

...
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
...

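For the 50-layer, 101-layer, and 152-layer networks, the dashed (projection) shortcut is reproduced as follows. The 18-layer and 34-layer networks work the same way, so they are not shown.

Figure 8.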

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )
        ...
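The condition in `_make_layer` decides when the dashed (projection) shortcut is needed. The sketch below mirrors that condition in a hypothetical helper, `make_downsample`, and shows the typical cases. A projection is needed when the channel count changes (the first block of layer1 in ResNet-50), or when both the stride and the channels change (the first block of layer2). No projection is needed for the later blocks, or for ResNet-18's layer1, where the identity shortcut already matches.

```python
import torch
from torch import nn


def make_downsample(in_channels, channels, expansion, stride):
    # Same condition as _make_layer: project the shortcut only when the
    # spatial size (stride) or the channel count changes.
    if stride != 1 or in_channels != channels * expansion:
        return nn.Sequential(
            nn.Conv2d(in_channels, channels * expansion, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(channels * expansion),
        )
    return None


ds0 = make_downsample(64, 64, 1, 1)    # ResNet-18 layer1: 64 -> 64, identity shortcut
ds1 = make_downsample(64, 64, 4, 1)    # ResNet-50 layer1 block 1: 64 -> 256, projection
ds2 = make_downsample(256, 64, 4, 1)   # ResNet-50 layer1 block 2: 256 -> 256, identity
ds3 = make_downsample(256, 128, 4, 2)  # ResNet-50 layer2 block 1: 256 -> 512, stride 2

x = torch.randn(1, 256, 56, 56)
print(ds0, ds2)       # None None
print(ds3(x).shape)   # torch.Size([1, 512, 28, 28])
```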

Then call the ResNet model with the desired depth (18, 34, 50, 101, 152):

def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)
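The depth in each factory name follows from `blocks_num`. Each BasicBlock has two weighted convolutions and each Bottleneck has three; the stem `conv1` and the final `fc` add two more. The 1x1 projection convolutions in `downsample` are not counted. The quick check below uses the block counts from the functions above.

```python
# (convs per block, blocks_num) for each depth, as passed to resnet_()
configs = {
    18: (2, [2, 2, 2, 2]),    # BasicBlock: 3x3 + 3x3
    34: (2, [3, 4, 6, 3]),
    50: (3, [3, 4, 6, 3]),    # Bottleneck: 1x1 + 3x3 + 1x1
    101: (3, [3, 4, 23, 3]),
    152: (3, [3, 8, 36, 3]),
}


def total_layers(convs_per_block, blocks_num):
    # + 2 for the 7x7 stem conv1 and the final fc layer;
    # the 1x1 downsample convolutions are not counted
    return convs_per_block * sum(blocks_num) + 2


for depth, (convs, blocks) in configs.items():
    print(depth, total_layers(convs, blocks))  # the two numbers agree
```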

Complete code

import torch
import torch.nn as nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152'
]
# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet18': '{}/resnet18-5c106cde.pth'.format(CheckPoints),
    # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet34': '{}/resnet34-333f7ec4.pth'.format(CheckPoints),
    # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet50': '{}/resnet50-19c8e357.pth'.format(CheckPoints),
    # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet101': '{}/resnet101-5d3b4d8f.pth'.format(CheckPoints),
    # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnet152': '{}/resnet152-b121ed2d.pth'.format(CheckPoints)

}


def resnet_(arch, block, block_num, num_classes, pretrained, include_top, **kwargs):
    model = resnet(block=block, blocks_num=block_num, num_classes=num_classes, include_top=include_top, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
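        # map_location="cpu" works on CPU and GPU machines alike;
        # load_state_dict() then copies the weights to the model's device
        pretrained_models = torch.load(model_urls["resnet" + arch], map_location=torch.device("cpu"))
        # transfer learning: the ImageNet fc layer has 1000 outputs, so drop it when
        # training on your own dataset (a size mismatch raises even with strict=False)
        if num_classes != 1000:
            del pretrained_models['fc.weight']
            del pretrained_models['fc.bias']
        model.load_state_dict(pretrained_models, strict=False)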
    return model


# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)

    to

    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)

    acc: up 0.5%
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                               stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # padding=1 keeps the spatial size aligned with the (downsampled) identity branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1,
                               stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x

        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class resnet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(resnet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.flatten = nn.Flatten()
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)

            )
        layers = []
        layers.append(block(in_channels=self.in_channels, out_channels=channels, downsample=downsample, stride=stride))
        self.in_channels = channels * block.expansion

        for _ in range(1, block_num):
            layers.append(
                block(in_channels=self.in_channels, out_channels=channels))

        return nn.Sequential(*layers)

    @autocast()
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = self.flatten(x)
            x = self.fc(x)
        return x


def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)

Some configuration files

utils/path.py

import os.path
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)
# BASE_DIR is four directory levels above this file,
# i.e. "/data/" for the project layout above

MyName = "PycharmProject"
Folder = "Simple-CV-Pytorch-master"

# Path to store checkpoint model
CheckPoints = 'checkpoints'
CheckPoints = os.path.join(BASE_DIR, MyName, Folder, CheckPoints)

# Path to store tensorboard logs
tensorboard_log = 'tensorboard'
tensorboard_log = os.path.join(BASE_DIR, MyName, Folder, tensorboard_log)

# Path to save log
log = 'log'
log = os.path.join(BASE_DIR, MyName, Folder, log)

# Path to save classification train log
classification_train_log = 'classification_train'

# Path to save classification test log
classification_test_log = 'classification_test'

# Path to save classification eval log
classification_eval_log = 'classification_eval'

# Classification evaluate model path
classification_evaluate = None


# Images classification path
image_cls = 'automobile.jpg'
images_cls_path = 'images/classification'
images_cls_path = os.path.join(BASE_DIR, MyName, Folder, images_cls_path, image_cls)

# Data
DATAPATH = BASE_DIR

# ImageNet/ILSVRC2012
ImageNet = "ImageNet/ILSVRC2012"
ImageNet_Train_path = os.path.join(DATAPATH, ImageNet, 'train')
ImageNet_Eval_path = os.path.join(DATAPATH, ImageNet, 'val')

# CIFAR10
CIFAR = 'cifar'
CIFAR_path = os.path.join(DATAPATH, CIFAR)
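The sketch below shows where `BASE_DIR` points. Applying `dirname` four times to the file path walks up from `utils/path.py` to `/data`. It uses `posixpath` (the POSIX flavour of `os.path`) so the result is the same on any OS. The file location is an assumption taken from the project tree above.

```python
import posixpath

# Assumed location of utils/path.py, following the project tree above.
path_py = "/data/PycharmProject/Simple-CV-Pytorch-master/utils/path.py"

base_dir = path_py
for _ in range(4):  # -> utils -> Simple-CV-Pytorch-master -> PycharmProject -> /data
    base_dir = posixpath.dirname(base_dir)
print(base_dir)  # /data

checkpoints = posixpath.join(base_dir, "PycharmProject", "Simple-CV-Pytorch-master", "checkpoints")
print(checkpoints)  # /data/PycharmProject/Simple-CV-Pytorch-master/checkpoints
```

This is also why `DATAPATH = BASE_DIR` resolves the datasets under `/data/`, matching the dataset layout described earlier.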

data/config.py

from utils import path

# Path to save log
log = path.log

# Path to save classification train log
classification_train_log = path.classification_train_log

# Path to save classification test log
classification_test_log = path.classification_test_log

# Path to save classification eval log
classification_eval_log = path.classification_eval_log

# Path to store checkpoint model
checkpoint_path = path.CheckPoints

# Classification evaluate model path
classification_evaluate = path.classification_evaluate

# Classification test images
images_cls_root = path.images_cls_path

# Path to save tensorboard
tensorboard_log = path.tensorboard_log

Training code

tools/classification/train.py

import os
import logging
import argparse
import warnings

warnings.filterwarnings('ignore')

import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

import time
import torch
from data import *
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torchvision import transforms
from utils.accuracy import accuracy
from torch.utils.data import DataLoader
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from utils.AverageMeter import AverageMeter
from torch.cuda.amp import autocast, GradScaler
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR',
                        choices=['ImageNet', 'CIFAR'],
                        help='ImageNet, CIFAR')
    parser.add_argument('--dataset_root',
                        type=str,
                        default=CIFAR_ROOT,
                        choices=[ImageNet_Train_ROOT, CIFAR_ROOT],
                        help='Dataset root directory path')
    parser.add_argument('--basenet',
                        type=str,
                        default='lenet',
                        choices=['resnet', 'vgg', 'lenet', 'alexnet'],
                        help='Pretrained base model')
    parser.add_argument('--depth',
                        type=int,
                        default=5,
                        help='BaseNet depth, including: LeNet of 5, AlexNet of 0, VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='Batch size for training')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='Checkpoint state_dict file to resume training from')
    parser.add_argument('--num_workers',
                        type=int,
                        default=8,
                        help='Number of workers used in data loading')
    parser.add_argument('--cuda',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),  # bool('False') would be True
                        default=True,
                        help='Use CUDA to train model')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='Momentum value for optim')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.1,
                        help='LR decay factor for the MultiStepLR scheduler')
    parser.add_argument('--accumulation_steps',
                        type=int,
                        default=1,
                        help='Gradient accumulation steps')
    parser.add_argument('--save_folder',
                        type=str,
                        default=config.checkpoint_path,
                        help='Directory for saving checkpoint models')
    parser.add_argument('--tensorboard',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='Use tensorboard for loss visualization')
    parser.add_argument('--log_folder',
                        type=str,
                        default=config.log,
                        help='Log Folder')
    parser.add_argument('--log_name',
                        type=str,
                        default=config.classification_train_log,
                        help='Log Name')
    parser.add_argument('--tensorboard_log',
                        type=str,
                        default=config.tensorboard_log,
                        help='Directory for tensorboard logs')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-2,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=30,
                        help='Number of epochs')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay')
    parser.add_argument('--milestones',
                        type=int,
                        nargs='+',
                        default=[15, 20, 30],
                        help='Milestones')
    parser.add_argument('--num_classes',
                        type=int,
                        default=10,
                        help='Number of classes, e.g. ImageNet: 1000, CIFAR: 10')
    parser.add_argument('--image_size',
                        type=int,
                        default=32,
                        help='Input image size, e.g. ImageNet: 224, CIFAR: 32')
    parser.add_argument('--pretrained',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=True,
                        help='Load pretrained weights')
    parser.add_argument('--init_weights',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='Initialize weights')

    return parser.parse_args()


args = parse_args()

# 1. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)

# 2. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)


def train():
    # 3. Create SummaryWriter
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # tensorboard  loss
        writer = SummaryWriter(args.tensorboard_log)
    # VGG, AlexNet and LeNet5 end in fully connected layers, so they need a fixed input size.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 4. Ready dataset
    if args.dataset == 'ImageNet':
        if args.dataset_root == CIFAR_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset ImageNet2012.')

        elif not os.path.exists(args.dataset_root):
            raise ValueError("ImageNet2012 dataset_root {} does not exist.".format(args.dataset_root))

        dataset = torchvision.datasets.ImageFolder(
            root=args.dataset_root,
            transform=torchvision.transforms.Compose([
                transforms.Resize((args.image_size,
                                   args.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ]))

    elif args.dataset == 'CIFAR':
        if args.dataset_root == ImageNet_Train_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset CIFAR10.')

        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")

        dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=True,
                                               transform=torchvision.transforms.Compose([
                                                   transforms.Resize((args.image_size,
                                                                      args.image_size)),
                                                   torchvision.transforms.ToTensor()]))
    else:
        raise ValueError('Dataset type not understood (must be ImageNet or CIFAR), exiting.')

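    # The sampler's generator must be on the same device as the default tensor type set above
    generator = torch.Generator(device='cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size,
                                             shuffle=True, num_workers=args.num_workers,
                                             pin_memory=False, generator=generator)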

    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # 5. Define train model

    # LeNet5 and AlexNet have no pretrained models available.
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes,
                           init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported LeNet depth!')

    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes,
                        init_weights=args.init_weights)

    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained,
                          num_classes=args.num_classes,
                          init_weights=args.init_weights)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained,
                          num_classes=args.num_classes,
                          init_weights=args.init_weights)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained,
                          num_classes=args.num_classes,
                          init_weights=args.init_weights)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained,
                          num_classes=args.num_classes,
                          init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported VGG depth!')
    # This ResNet takes no init_weights argument, because it will be reused as an object detection backbone
    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained,
                             num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained,
                             num_classes=args.num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=args.pretrained,
                             num_classes=args.num_classes)  # pretrained=False means no pretrained weights are loaded
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained,
                              num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained,
                              num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')

    else:
        raise ValueError('Unsupported model type!')

    # Always wrap in DataParallel so checkpoint keys ('module.' prefix) are consistent
    if args.cuda and torch.cuda.is_available():
        model = torch.nn.DataParallel(model.cuda())
    else:
        model = torch.nn.DataParallel(model)

    # 6. Loading weights
    if args.resume:
        other, ext = os.path.splitext(args.resume)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_load = os.path.join(args.save_folder, args.resume)
            model.load_state_dict(torch.load(model_load))
        else:
            print('Sorry only .pth and .pkl files supported.')
    if args.init_weights:
        # initialize newly added models' weights with xavier method
        if args.basenet == 'resnet':
            print("ResNet takes no init_weights option, because it is reused as an object detection backbone.")
        else:
            print("Initializing weights...")
    else:
        print("Not Initializing weights...")
    if args.pretrained:
        if args.basenet == 'lenet' or args.basenet == 'alexnet':
            print("No pretrained model is available for LeNet5 or AlexNet.")
        else:
            print("Using pretrained weights...")
    else:
        print("Not using pretrained weights...")

    model.train()

    iteration = 0

    # 7. Optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    scaler = GradScaler()

    # 8. Length
    iter_size = len(dataset) // args.batch_size
    print("len(dataset): {}, iter_size: {}".format(len(dataset), iter_size))
    logger.info(f"args - {args}")
    t0 = time.time()

    # 9. Create batch iterator
    for epoch in range(args.epochs):
        t1 = time.time()
        torch.cuda.empty_cache()
        # 10. Load train data
        for data in dataloader:
            iteration += 1
            images, targets = data
            if args.cuda and torch.cuda.is_available():
                images, targets = images.cuda(), targets.cuda()
            # 11. Forward (mixed precision); divide the loss so accumulated gradients average out
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
                loss = loss / args.accumulation_steps

            if args.tensorboard:
                writer.add_scalar("train_classification_loss", loss.item(), iteration)

            # 12. Backward; update the weights only every accumulation_steps iterations
            scaler.scale(loss).backward()
            if iteration % args.accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # 13. Measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            losses.update(loss.item(), images.size(0))

            if iteration % 100 == 0:
                logger.info(
                    f"- epoch: {epoch},  iteration: {iteration}, lr: {optimizer.param_groups[0]['lr']}, "
                    f"top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, "
                    f"loss: {loss.item():.3f}, avg loss: {losses.avg:.3f}"
                )

        # MultiStepLR counts epochs; it takes no loss value
        scheduler.step()

        t2 = time.time()
        h_time = (t2 - t1) // 3600
        m_time = ((t2 - t1) % 3600) // 60
        s_time = ((t2 - t1) % 3600) % 60
        print("epoch {} is finished, and the time is {}h{}min{}s".format(epoch, int(h_time), int(m_time), int(s_time)))

        # 14. Save train model
        if epoch != 0 and epoch % 10 == 0:
            print('Saving state, epoch:', epoch)
            torch.save(model.state_dict(),
                       args.save_folder + '/' + args.dataset +
                       '_' + args.basenet + str(args.depth) + '_' + repr(epoch) + '.pth')
        torch.save(model.state_dict(),
                   args.save_folder + '/' + args.dataset + "_" + args.basenet + str(args.depth) + '.pth')

    if args.tensorboard:
        writer.close()

    t3 = time.time()
    h = (t3 - t0) // 3600
    m = ((t3 - t0) % 3600) // 60
    s = ((t3 - t0) % 3600) % 60
    print("The Finished Time is {}h{}m{}s".format(int(h), int(m), int(s)))
    return top1.avg, top5.avg, losses.avg


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    top1, top5, loss = train()
    print("top1 acc: {}, top5 acc: {}, loss:{}".format(top1, top5, loss))
    logger.info("Done!")
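The `--lr`, `--milestones` and `--gamma` arguments drive `torch.optim.lr_scheduler.MultiStepLR`: each time the epoch counter reaches a milestone, the learning rate is multiplied by `gamma`. The pure-Python sketch below mimics that rule (it is not the PyTorch implementation, and `multistep_lr` is a hypothetical helper) to show which learning rate each epoch uses with the default settings. It assumes `scheduler.step()` is called once at the end of every epoch, with no arguments.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate used during `epoch` (0-indexed) under a MultiStepLR-style rule.

    Assumes the scheduler is stepped once at the end of every epoch, so the
    rate has been multiplied by `gamma` once for each milestone <= epoch.
    """
    decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** decays


if __name__ == "__main__":
    # Defaults from train.py: lr=1e-2, milestones=[15, 20, 30], gamma=0.1, epochs=30
    for epoch in range(30):
        print(epoch, multistep_lr(1e-2, [15, 20, 30], 0.1, epoch))
```

With the default 30 epochs (indices 0 to 29), the milestone at 30 is never reached, so only two decays take effect.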

Test code

tools/classification/test.py

import logging
import os
import argparse
import warnings

warnings.filterwarnings('ignore')

import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

import time
from data import *
from PIL import Image
import torch.nn.parallel
from torchvision import transforms
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Testing')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR',
                        choices=['ImageNet', 'CIFAR'],
                        help='ImageNet, CIFAR')
    parser.add_argument('--images_root',
                        type=str,
                        default=config.images_cls_root,
                        help='Path to the image to classify')
    parser.add_argument('--basenet',
                        type=str,
                        default='alexnet',
                        choices=['resnet', 'vgg', 'lenet', 'alexnet'],
                        help='Pretrained base model')
    parser.add_argument('--depth',
                        type=int,
                        default=0,
                        help='BaseNet depth, including: LeNet of 5, AlexNet of 0, VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--evaluate',
                        type=str,
                        default=config.classification_evaluate,
                        help='Checkpoint state_dict file to evaluate')
    parser.add_argument('--save_folder',
                        type=str,
                        default=config.checkpoint_path,
                        help='Directory for saving checkpoint models')
    parser.add_argument('--log_folder',
                        type=str,
                        default=config.log,
                        help='Log Folder')
    parser.add_argument('--log_name',
                        type=str,
                        default=config.classification_test_log,
                        help='Log Name')
    parser.add_argument('--cuda',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),  # bool('False') would be True
                        default=True,
                        help='Use CUDA to run the model')
    parser.add_argument('--num_classes',
                        type=int,
                        default=10,
                        help='Number of classes, e.g. ImageNet: 1000, CIFAR: 10')
    parser.add_argument('--image_size',
                        type=int,
                        default=32,
                        help='Input image size, e.g. ImageNet: 224, CIFAR: 32')
    parser.add_argument('--pretrained',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='Load pretrained weights')

    return parser.parse_args()


args = parse_args()

# 1. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

# 2. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)


def get_label_file(filename):
    if not os.path.exists(filename):
        raise FileNotFoundError("Label file {} not found; please create it first.".format(filename))
    return filename


def dataset_labels_results(filename, output):
    import ast  # the label file stores a dict literal such as {0: 'airplane', ...}

    filename = os.path.join(BASE_DIR, 'data', filename + '_labels.txt')
    get_label_file(filename=filename)
    with open(file=filename, mode='r') as f:
        labels = ast.literal_eval(f.read())
    # output holds the argmax class index of a single image
    index = int(output.cpu().numpy()[0])
    return labels[index]


def test():
    # VGG, AlexNet and LeNet5 end in fully connected layers, so they need a fixed input size.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 3. Ready image
    if args.images_root is None:
        raise ValueError("The images is None, you should load image!")

    image = Image.open(args.images_root).convert('RGB')  # ensure 3 channels for the reshape below
    transform = transforms.Compose([
        transforms.Resize((args.image_size,
                           args.image_size)),
        transforms.ToTensor()])

    image = transform(image)

    image = image.reshape(1, 3, args.image_size, args.image_size)

    # 4. Define the model
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported LeNet depth!')
    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes)

    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported VGG depth!')

    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained,
                             num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained,
                             num_classes=args.num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=args.pretrained,
                             num_classes=args.num_classes)  # pretrained=False means no pretrained weights are loaded
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained,
                              num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained,
                              num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')
    else:
        raise ValueError('Unsupported model type!')

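    # Always wrap in DataParallel so checkpoint keys ('module.' prefix) match the training script
    if args.cuda and torch.cuda.is_available():
        model = torch.nn.DataParallel(model.cuda())
    else:
        model = torch.nn.DataParallel(model)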

    # 5. Loading model
    if args.evaluate:
        other, ext = os.path.splitext(args.evaluate)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_evaluate_load = os.path.join(args.save_folder, args.evaluate)
            # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only machine
            model.load_state_dict(torch.load(model_evaluate_load, map_location='cpu'))
        else:
            print('Sorry only .pth and .pkl files supported.')
    else:
        print("No weights were loaded; please pass --evaluate.")

    model.eval()

    # 6. Log the arguments
    logger.info(f"args - {args}")

    # 7. Test
    with torch.no_grad():
        t0 = time.time()
        # 8. Forward
        if args.cuda:
            image = image.cuda()
        output = model(image)
        output = output.argmax(1)
        t1 = time.time()
        m = (t1 - t0) // 60
        s = (t1 - t0) % 60
        folder_name = args.dataset
        output = dataset_labels_results(filename=folder_name, output=output)
        logger.info(f"output: {output}")
        print("It took a total of {}m{}s to complete the testing.".format(int(m), int(s)))
    return output


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    output = test()
    logger.info("Done!")
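Boolean switches such as `--cuda`, `--pretrained`, `--tensorboard` and `--init_weights` need care with argparse: a value given on the command line always arrives as a string, and any non-empty string, including `"False"`, is truthy in Python. A common workaround is a small converter passed as `type=`. The sketch below is one such converter; `str2bool` and `build_parser` are hypothetical names, not part of the repository or of argparse.

```python
import argparse


def str2bool(value):
    """Convert common textual spellings of a boolean into True/False.

    Hypothetical helper: argparse itself has no built-in string-to-bool type.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser(description="Boolean flag demo")
    # Non-string defaults are kept as-is; command-line strings go through str2bool
    parser.add_argument("--cuda", type=str2bool, default=True)
    parser.add_argument("--pretrained", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--cuda", "False"])
    print(args.cuda, args.pretrained)
```

Raising `ArgumentTypeError` for unknown spellings makes argparse print a usage error instead of silently treating a typo as `True`.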

Labels

CIFAR_label.txt

{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}
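The label files are stored as Python dict literals that map a class index to its name. One way to read such a file is `ast.literal_eval`, which evaluates literals only and never runs arbitrary code. The sketch below parses the CIFAR mapping from an in-memory string instead of a file so that it stays self-contained; `load_labels` and `index_to_name` are hypothetical helper names.

```python
import ast

# Same content as the CIFAR label file above (index -> class name)
CIFAR_LABELS_TEXT = """{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}"""


def load_labels(text):
    """Parse a dict-literal label file into {int: str}."""
    labels = ast.literal_eval(text)
    if not isinstance(labels, dict):
        raise ValueError("label file must contain a dict literal")
    return labels


def index_to_name(labels, index):
    """Map a predicted class index (e.g. from output.argmax(1)) to its name."""
    return labels[int(index)]


if __name__ == "__main__":
    labels = load_labels(CIFAR_LABELS_TEXT)
    print(len(labels), index_to_name(labels, 3))
```

Reading the file line by line and indexing by line number would instead return raw text such as `" 3: 'cat',\n"` rather than just the class name.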

ImageNet_label.txt

{0: 'tench, Tinca tinca',
 1: 'goldfish, Carassius auratus',
 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
 3: 'tiger shark, Galeocerdo cuvieri',
 4: 'hammerhead, hammerhead shark',
 5: 'electric ray, crampfish, numbfish, torpedo',
 6: 'stingray',
 7: 'cock',
 8: 'hen',
 9: 'ostrich, Struthio camelus',
 10: 'brambling, Fringilla montifringilla',
 11: 'goldfinch, Carduelis carduelis',
 12: 'house finch, linnet, Carpodacus mexicanus',
 13: 'junco, snowbird',
 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
 15: 'robin, American robin, Turdus migratorius',
 16: 'bulbul',
 17: 'jay',
 18: 'magpie',
 19: 'chickadee',
 20: 'water ouzel, dipper',
 21: 'kite',
 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
 23: 'vulture',
 24: 'great grey owl, great gray owl, Strix nebulosa',
 25: 'European fire salamander, Salamandra salamandra',
 26: 'common newt, Triturus vulgaris',
 27: 'eft',
 28: 'spotted salamander, Ambystoma maculatum',
 29: 'axolotl, mud puppy, Ambystoma mexicanum',
 30: 'bullfrog, Rana catesbeiana',
 31: 'tree frog, tree-frog',
 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
 33: 'loggerhead, loggerhead turtle, Caretta caretta',
 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
 35: 'mud turtle',
 36: 'terrapin',
 37: 'box turtle, box tortoise',
 38: 'banded gecko',
 39: 'common iguana, iguana, Iguana iguana',
 40: 'American chameleon, anole, Anolis carolinensis',
 41: 'whiptail, whiptail lizard',
 42: 'agama',
 43: 'frilled lizard, Chlamydosaurus kingi',
 44: 'alligator lizard',
 45: 'Gila monster, Heloderma suspectum',
 46: 'green lizard, Lacerta viridis',
 47: 'African chameleon, Chamaeleo chamaeleon',
 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
 50: 'American alligator, Alligator mississipiensis',
 51: 'triceratops',
 52: 'thunder snake, worm snake, Carphophis amoenus',
 53: 'ringneck snake, ring-necked snake, ring snake',
 54: 'hognose snake, puff adder, sand viper',
 55: 'green snake, grass snake',
 56: 'king snake, kingsnake',
 57: 'garter snake, grass snake',
 58: 'water snake',
 59: 'vine snake',
 60: 'night snake, Hypsiglena torquata',
 61: 'boa constrictor, Constrictor constrictor',
 62: 'rock python, rock snake, Python sebae',
 63: 'Indian cobra, Naja naja',
 64: 'green mamba',
 65: 'sea snake',
 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
 69: 'trilobite',
 70: 'harvestman, daddy longlegs, Phalangium opilio',
 71: 'scorpion',
 72: 'black and gold garden spider, Argiope aurantia',
 73: 'barn spider, Araneus cavaticus',
 74: 'garden spider, Aranea diademata',
 75: 'black widow, Latrodectus mactans',
 76: 'tarantula',
 77: 'wolf spider, hunting spider',
 78: 'tick',
 79: 'centipede',
 80: 'black grouse',
 81: 'ptarmigan',
 82: 'ruffed grouse, partridge, Bonasa umbellus',
 83: 'prairie chicken, prairie grouse, prairie fowl',
 84: 'peacock',
 85: 'quail',
 86: 'partridge',
 87: 'African grey, African gray, Psittacus erithacus',
 88: 'macaw',
 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
 90: 'lorikeet',
 91: 'coucal',
 92: 'bee eater',
 93: 'hornbill',
 94: 'hummingbird',
 95: 'jacamar',
 96: 'toucan',
 97: 'drake',
 98: 'red-breasted merganser, Mergus serrator',
 99: 'goose',
 100: 'black swan, Cygnus atratus',
 101: 'tusker',
 102: 'echidna, spiny anteater, anteater',
 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
 104: 'wallaby, brush kangaroo',
 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
 106: 'wombat',
 107: 'jellyfish',
 108: 'sea anemone, anemone',
 109: 'brain coral',
 110: 'flatworm, platyhelminth',
 111: 'nematode, nematode worm, roundworm',
 112: 'conch',
 113: 'snail',
 114: 'slug',
 115: 'sea slug, nudibranch',
 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
 117: 'chambered nautilus, pearly nautilus, nautilus',
 118: 'Dungeness crab, Cancer magister',
 119: 'rock crab, Cancer irroratus',
 120: 'fiddler crab',
 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
 124: 'crayfish, crawfish, crawdad, crawdaddy',
 125: 'hermit crab',
 126: 'isopod',
 127: 'white stork, Ciconia ciconia',
 128: 'black stork, Ciconia nigra',
 129: 'spoonbill',
 130: 'flamingo',
 131: 'little blue heron, Egretta caerulea',
 132: 'American egret, great white heron, Egretta albus',
 133: 'bittern',
 134: 'crane',
 135: 'limpkin, Aramus pictus',
 136: 'European gallinule, Porphyrio porphyrio',
 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
 138: 'bustard',
 139: 'ruddy turnstone, Arenaria interpres',
 140: 'red-backed sandpiper, dunlin, Erolia alpina',
 141: 'redshank, Tringa totanus',
 142: 'dowitcher',
 143: 'oystercatcher, oyster catcher',
 144: 'pelican',
 145: 'king penguin, Aptenodytes patagonica',
 146: 'albatross, mollymawk',
 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
 149: 'dugong, Dugong dugon',
 150: 'sea lion',
 151: 'Chihuahua',
 152: 'Japanese spaniel',
 153: 'Maltese dog, Maltese terrier, Maltese',
 154: 'Pekinese, Pekingese, Peke',
 155: 'Shih-Tzu',
 156: 'Blenheim spaniel',
 157: 'papillon',
 158: 'toy terrier',
 159: 'Rhodesian ridgeback',
 160: 'Afghan hound, Afghan',
 161: 'basset, basset hound',
 162: 'beagle',
 163: 'bloodhound, sleuthhound',
 164: 'bluetick',
 165: 'black-and-tan coonhound',
 166: 'Walker hound, Walker foxhound',
 167: 'English foxhound',
 168: 'redbone',
 169: 'borzoi, Russian wolfhound',
 170: 'Irish wolfhound',
 171: 'Italian greyhound',
 172: 'whippet',
 173: 'Ibizan hound, Ibizan Podenco',
 174: 'Norwegian elkhound, elkhound',
 175: 'otterhound, otter hound',
 176: 'Saluki, gazelle hound',
 177: 'Scottish deerhound, deerhound',
 178: 'Weimaraner',
 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
 181: 'Bedlington terrier',
 182: 'Border terrier',
 183: 'Kerry blue terrier',
 184: 'Irish terrier',
 185: 'Norfolk terrier',
 186: 'Norwich terrier',
 187: 'Yorkshire terrier',
 188: 'wire-haired fox terrier',
 189: 'Lakeland terrier',
 190: 'Sealyham terrier, Sealyham',
 191: 'Airedale, Airedale terrier',
 192: 'cairn, cairn terrier',
 193: 'Australian terrier',
 194: 'Dandie Dinmont, Dandie Dinmont terrier',
 195: 'Boston bull, Boston terrier',
 196: 'miniature schnauzer',
 197: 'giant schnauzer',
 198: 'standard schnauzer',
 199: 'Scotch terrier, Scottish terrier, Scottie',
 200: 'Tibetan terrier, chrysanthemum dog',
 201: 'silky terrier, Sydney silky',
 202: 'soft-coated wheaten terrier',
 203: 'West Highland white terrier',
 204: 'Lhasa, Lhasa apso',
 205: 'flat-coated retriever',
 206: 'curly-coated retriever',
 207: 'golden retriever',
 208: 'Labrador retriever',
 209: 'Chesapeake Bay retriever',
 210: 'German short-haired pointer',
 211: 'vizsla, Hungarian pointer',
 212: 'English setter',
 213: 'Irish setter, red setter',
 214: 'Gordon setter',
 215: 'Brittany spaniel',
 216: 'clumber, clumber spaniel',
 217: 'English springer, English springer spaniel',
 218: 'Welsh springer spaniel',
 219: 'cocker spaniel, English cocker spaniel, cocker',
 220: 'Sussex spaniel',
 221: 'Irish water spaniel',
 222: 'kuvasz',
 223: 'schipperke',
 224: 'groenendael',
 225: 'malinois',
 226: 'briard',
 227: 'kelpie',
 228: 'komondor',
 229: 'Old English sheepdog, bobtail',
 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
 231: 'collie',
 232: 'Border collie',
 233: 'Bouvier des Flandres, Bouviers des Flandres',
 234: 'Rottweiler',
 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
 236: 'Doberman, Doberman pinscher',
 237: 'miniature pinscher',
 238: 'Greater Swiss Mountain dog',
 239: 'Bernese mountain dog',
 240: 'Appenzeller',
 241: 'EntleBucher',
 242: 'boxer',
 243: 'bull mastiff',
 244: 'Tibetan mastiff',
 245: 'French bulldog',
 246: 'Great Dane',
 247: 'Saint Bernard, St Bernard',
 248: 'Eskimo dog, husky',
 249: 'malamute, malemute, Alaskan malamute',
 250: 'Siberian husky',
 251: 'dalmatian, coach dog, carriage dog',
 252: 'affenpinscher, monkey pinscher, monkey dog',
 253: 'basenji',
 254: 'pug, pug-dog',
 255: 'Leonberg',
 256: 'Newfoundland, Newfoundland dog',
 257: 'Great Pyrenees',
 258: 'Samoyed, Samoyede',
 259: 'Pomeranian',
 260: 'chow, chow chow',
 261: 'keeshond',
 262: 'Brabancon griffon',
 263: 'Pembroke, Pembroke Welsh corgi',
 264: 'Cardigan, Cardigan Welsh corgi',
 265: 'toy poodle',
 266: 'miniature poodle',
 267: 'standard poodle',
 268: 'Mexican hairless',
 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
 273: 'dingo, warrigal, warragal, Canis dingo',
 274: 'dhole, Cuon alpinus',
 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
 276: 'hyena, hyaena',
 277: 'red fox, Vulpes vulpes',
 278: 'kit fox, Vulpes macrotis',
 279: 'Arctic fox, white fox, Alopex lagopus',
 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
 281: 'tabby, tabby cat',
 282: 'tiger cat',
 283: 'Persian cat',
 284: 'Siamese cat, Siamese',
 285: 'Egyptian cat',
 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
 287: 'lynx, catamount',
 288: 'leopard, Panthera pardus',
 289: 'snow leopard, ounce, Panthera uncia',
 290: 'jaguar, panther, Panthera onca, Felis onca',
 291: 'lion, king of beasts, Panthera leo',
 292: 'tiger, Panthera tigris',
 293: 'cheetah, chetah, Acinonyx jubatus',
 294: 'brown bear, bruin, Ursus arctos',
 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
 298: 'mongoose',
 299: 'meerkat, mierkat',
 300: 'tiger beetle',
 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
 302: 'ground beetle, carabid beetle',
 303: 'long-horned beetle, longicorn, longicorn beetle',
 304: 'leaf beetle, chrysomelid',
 305: 'dung beetle',
 306: 'rhinoceros beetle',
 307: 'weevil',
 308: 'fly',
 309: 'bee',
 310: 'ant, emmet, pismire',
 311: 'grasshopper, hopper',
 312: 'cricket',
 313: 'walking stick, walkingstick, stick insect',
 314: 'cockroach, roach',
 315: 'mantis, mantid',
 316: 'cicada, cicala',
 317: 'leafhopper',
 318: 'lacewing, lacewing fly',
 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
 320: 'damselfly',
 321: 'admiral',
 322: 'ringlet, ringlet butterfly',
 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
 324: 'cabbage butterfly',
 325: 'sulphur butterfly, sulfur butterfly',
 326: 'lycaenid, lycaenid butterfly',
 327: 'starfish, sea star',
 328: 'sea urchin',
 329: 'sea cucumber, holothurian',
 330: 'wood rabbit, cottontail, cottontail rabbit',
 331: 'hare',
 332: 'Angora, Angora rabbit',
 333: 'hamster',
 334: 'porcupine, hedgehog',
 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
 336: 'marmot',
 337: 'beaver',
 338: 'guinea pig, Cavia cobaya',
 339: 'sorrel',
 340: 'zebra',
 341: 'hog, pig, grunter, squealer, Sus scrofa',
 342: 'wild boar, boar, Sus scrofa',
 343: 'warthog',
 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
 345: 'ox',
 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
 347: 'bison',
 348: 'ram, tup',
 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
 350: 'ibex, Capra ibex',
 351: 'hartebeest',
 352: 'impala, Aepyceros melampus',
 353: 'gazelle',
 354: 'Arabian camel, dromedary, Camelus dromedarius',
 355: 'llama',
 356: 'weasel',
 357: 'mink',
 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
 359: 'black-footed ferret, ferret, Mustela nigripes',
 360: 'otter',
 361: 'skunk, polecat, wood pussy',
 362: 'badger',
 363: 'armadillo',
 364: 'three-toed sloth, ai, Bradypus tridactylus',
 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
 366: 'gorilla, Gorilla gorilla',
 367: 'chimpanzee, chimp, Pan troglodytes',
 368: 'gibbon, Hylobates lar',
 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
 370: 'guenon, guenon monkey',
 371: 'patas, hussar monkey, Erythrocebus patas',
 372: 'baboon',
 373: 'macaque',
 374: 'langur',
 375: 'colobus, colobus monkey',
 376: 'proboscis monkey, Nasalis larvatus',
 377: 'marmoset',
 378: 'capuchin, ringtail, Cebus capucinus',
 379: 'howler monkey, howler',
 380: 'titi, titi monkey',
 381: 'spider monkey, Ateles geoffroyi',
 382: 'squirrel monkey, Saimiri sciureus',
 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
 384: 'indri, indris, Indri indri, Indri brevicaudatus',
 385: 'Indian elephant, Elephas maximus',
 386: 'African elephant, Loxodonta africana',
 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
 389: 'barracouta, snoek',
 390: 'eel',
 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
 392: 'rock beauty, Holocanthus tricolor',
 393: 'anemone fish',
 394: 'sturgeon',
 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
 396: 'lionfish',
 397: 'puffer, pufferfish, blowfish, globefish',
 398: 'abacus',
 399: 'abaya',
 400: "academic gown, academic robe, judge's robe",
 401: 'accordion, piano accordion, squeeze box',
 402: 'acoustic guitar',
 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
 404: 'airliner',
 405: 'airship, dirigible',
 406: 'altar',
 407: 'ambulance',
 408: 'amphibian, amphibious vehicle',
 409: 'analog clock',
 410: 'apiary, bee house',
 411: 'apron',
 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
 413: 'assault rifle, assault gun',
 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
 415: 'bakery, bakeshop, bakehouse',
 416: 'balance beam, beam',
 417: 'balloon',
 418: 'ballpoint, ballpoint pen, ballpen, Biro',
 419: 'Band Aid',
 420: 'banjo',
 421: 'bannister, banister, balustrade, balusters, handrail',
 422: 'barbell',
 423: 'barber chair',
 424: 'barbershop',
 425: 'barn',
 426: 'barometer',
 427: 'barrel, cask',
 428: 'barrow, garden cart, lawn cart, wheelbarrow',
 429: 'baseball',
 430: 'basketball',
 431: 'bassinet',
 432: 'bassoon',
 433: 'bathing cap, swimming cap',
 434: 'bath towel',
 435: 'bathtub, bathing tub, bath, tub',
 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
 437: 'beacon, lighthouse, beacon light, pharos',
 438: 'beaker',
 439: 'bearskin, busby, shako',
 440: 'beer bottle',
 441: 'beer glass',
 442: 'bell cote, bell cot',
 443: 'bib',
 444: 'bicycle-built-for-two, tandem bicycle, tandem',
 445: 'bikini, two-piece',
 446: 'binder, ring-binder',
 447: 'binoculars, field glasses, opera glasses',
 448: 'birdhouse',
 449: 'boathouse',
 450: 'bobsled, bobsleigh, bob',
 451: 'bolo tie, bolo, bola tie, bola',
 452: 'bonnet, poke bonnet',
 453: 'bookcase',
 454: 'bookshop, bookstore, bookstall',
 455: 'bottlecap',
 456: 'bow',
 457: 'bow tie, bow-tie, bowtie',
 458: 'brass, memorial tablet, plaque',
 459: 'brassiere, bra, bandeau',
 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
 461: 'breastplate, aegis, egis',
 462: 'broom',
 463: 'bucket, pail',
 464: 'buckle',
 465: 'bulletproof vest',
 466: 'bullet train, bullet',
 467: 'butcher shop, meat market',
 468: 'cab, hack, taxi, taxicab',
 469: 'caldron, cauldron',
 470: 'candle, taper, wax light',
 471: 'cannon',
 472: 'canoe',
 473: 'can opener, tin opener',
 474: 'cardigan',
 475: 'car mirror',
 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
 477: "carpenter's kit, tool kit",
 478: 'carton',
 479: 'car wheel',
 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
 481: 'cassette',
 482: 'cassette player',
 483: 'castle',
 484: 'catamaran',
 485: 'CD player',
 486: 'cello, violoncello',
 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
 488: 'chain',
 489: 'chainlink fence',
 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
 491: 'chain saw, chainsaw',
 492: 'chest',
 493: 'chiffonier, commode',
 494: 'chime, bell, gong',
 495: 'china cabinet, china closet',
 496: 'Christmas stocking',
 497: 'church, church building',
 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
 499: 'cleaver, meat cleaver, chopper',
 500: 'cliff dwelling',
 501: 'cloak',
 502: 'clog, geta, patten, sabot',
 503: 'cocktail shaker',
 504: 'coffee mug',
 505: 'coffeepot',
 506: 'coil, spiral, volute, whorl, helix',
 507: 'combination lock',
 508: 'computer keyboard, keypad',
 509: 'confectionery, confectionary, candy store',
 510: 'container ship, containership, container vessel',
 511: 'convertible',
 512: 'corkscrew, bottle screw',
 513: 'cornet, horn, trumpet, trump',
 514: 'cowboy boot',
 515: 'cowboy hat, ten-gallon hat',
 516: 'cradle',
 517: 'crane',
 518: 'crash helmet',
 519: 'crate',
 520: 'crib, cot',
 521: 'Crock Pot',
 522: 'croquet ball',
 523: 'crutch',
 524: 'cuirass',
 525: 'dam, dike, dyke',
 526: 'desk',
 527: 'desktop computer',
 528: 'dial telephone, dial phone',
 529: 'diaper, nappy, napkin',
 530: 'digital clock',
 531: 'digital watch',
 532: 'dining table, board',
 533: 'dishrag, dishcloth',
 534: 'dishwasher, dish washer, dishwashing machine',
 535: 'disk brake, disc brake',
 536: 'dock, dockage, docking facility',
 537: 'dogsled, dog sled, dog sleigh',
 538: 'dome',
 539: 'doormat, welcome mat',
 540: 'drilling platform, offshore rig',
 541: 'drum, membranophone, tympan',
 542: 'drumstick',
 543: 'dumbbell',
 544: 'Dutch oven',
 545: 'electric fan, blower',
 546: 'electric guitar',
 547: 'electric locomotive',
 548: 'entertainment center',
 549: 'envelope',
 550: 'espresso maker',
 551: 'face powder',
 552: 'feather boa, boa',
 553: 'file, file cabinet, filing cabinet',
 554: 'fireboat',
 555: 'fire engine, fire truck',
 556: 'fire screen, fireguard',
 557: 'flagpole, flagstaff',
 558: 'flute, transverse flute',
 559: 'folding chair',
 560: 'football helmet',
 561: 'forklift',
 562: 'fountain',
 563: 'fountain pen',
 564: 'four-poster',
 565: 'freight car',
 566: 'French horn, horn',
 567: 'frying pan, frypan, skillet',
 568: 'fur coat',
 569: 'garbage truck, dustcart',
 570: 'gasmask, respirator, gas helmet',
 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
 572: 'goblet',
 573: 'go-kart',
 574: 'golf ball',
 575: 'golfcart, golf cart',
 576: 'gondola',
 577: 'gong, tam-tam',
 578: 'gown',
 579: 'grand piano, grand',
 580: 'greenhouse, nursery, glasshouse',
 581: 'grille, radiator grille',
 582: 'grocery store, grocery, food market, market',
 583: 'guillotine',
 584: 'hair slide',
 585: 'hair spray',
 586: 'half track',
 587: 'hammer',
 588: 'hamper',
 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
 590: 'hand-held computer, hand-held microcomputer',
 591: 'handkerchief, hankie, hanky, hankey',
 592: 'hard disc, hard disk, fixed disk',
 593: 'harmonica, mouth organ, harp, mouth harp',
 594: 'harp',
 595: 'harvester, reaper',
 596: 'hatchet',
 597: 'holster',
 598: 'home theater, home theatre',
 599: 'honeycomb',
 600: 'hook, claw',
 601: 'hoopskirt, crinoline',
 602: 'horizontal bar, high bar',
 603: 'horse cart, horse-cart',
 604: 'hourglass',
 605: 'iPod',
 606: 'iron, smoothing iron',
 607: "jack-o'-lantern",
 608: 'jean, blue jean, denim',
 609: 'jeep, landrover',
 610: 'jersey, T-shirt, tee shirt',
 611: 'jigsaw puzzle',
 612: 'jinrikisha, ricksha, rickshaw',
 613: 'joystick',
 614: 'kimono',
 615: 'knee pad',
 616: 'knot',
 617: 'lab coat, laboratory coat',
 618: 'ladle',
 619: 'lampshade, lamp shade',
 620: 'laptop, laptop computer',
 621: 'lawn mower, mower',
 622: 'lens cap, lens cover',
 623: 'letter opener, paper knife, paperknife',
 624: 'library',
 625: 'lifeboat',
 626: 'lighter, light, igniter, ignitor',
 627: 'limousine, limo',
 628: 'liner, ocean liner',
 629: 'lipstick, lip rouge',
 630: 'Loafer',
 631: 'lotion',
 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
 633: "loupe, jeweler's loupe",
 634: 'lumbermill, sawmill',
 635: 'magnetic compass',
 636: 'mailbag, postbag',
 637: 'mailbox, letter box',
 638: 'maillot',
 639: 'maillot, tank suit',
 640: 'manhole cover',
 641: 'maraca',
 642: 'marimba, xylophone',
 643: 'mask',
 644: 'matchstick',
 645: 'maypole',
 646: 'maze, labyrinth',
 647: 'measuring cup',
 648: 'medicine chest, medicine cabinet',
 649: 'megalith, megalithic structure',
 650: 'microphone, mike',
 651: 'microwave, microwave oven',
 652: 'military uniform',
 653: 'milk can',
 654: 'minibus',
 655: 'miniskirt, mini',
 656: 'minivan',
 657: 'missile',
 658: 'mitten',
 659: 'mixing bowl',
 660: 'mobile home, manufactured home',
 661: 'Model T',
 662: 'modem',
 663: 'monastery',
 664: 'monitor',
 665: 'moped',
 666: 'mortar',
 667: 'mortarboard',
 668: 'mosque',
 669: 'mosquito net',
 670: 'motor scooter, scooter',
 671: 'mountain bike, all-terrain bike, off-roader',
 672: 'mountain tent',
 673: 'mouse, computer mouse',
 674: 'mousetrap',
 675: 'moving van',
 676: 'muzzle',
 677: 'nail',
 678: 'neck brace',
 679: 'necklace',
 680: 'nipple',
 681: 'notebook, notebook computer',
 682: 'obelisk',
 683: 'oboe, hautboy, hautbois',
 684: 'ocarina, sweet potato',
 685: 'odometer, hodometer, mileometer, milometer',
 686: 'oil filter',
 687: 'organ, pipe organ',
 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
 689: 'overskirt',
 690: 'oxcart',
 691: 'oxygen mask',
 692: 'packet',
 693: 'paddle, boat paddle',
 694: 'paddlewheel, paddle wheel',
 695: 'padlock',
 696: 'paintbrush',
 697: "pajama, pyjama, pj's, jammies",
 698: 'palace',
 699: 'panpipe, pandean pipe, syrinx',
 700: 'paper towel',
 701: 'parachute, chute',
 702: 'parallel bars, bars',
 703: 'park bench',
 704: 'parking meter',
 705: 'passenger car, coach, carriage',
 706: 'patio, terrace',
 707: 'pay-phone, pay-station',
 708: 'pedestal, plinth, footstall',
 709: 'pencil box, pencil case',
 710: 'pencil sharpener',
 711: 'perfume, essence',
 712: 'Petri dish',
 713: 'photocopier',
 714: 'pick, plectrum, plectron',
 715: 'pickelhaube',
 716: 'picket fence, paling',
 717: 'pickup, pickup truck',
 718: 'pier',
 719: 'piggy bank, penny bank',
 720: 'pill bottle',
 721: 'pillow',
 722: 'ping-pong ball',
 723: 'pinwheel',
 724: 'pirate, pirate ship',
 725: 'pitcher, ewer',
 726: "plane, carpenter's plane, woodworking plane",
 727: 'planetarium',
 728: 'plastic bag',
 729: 'plate rack',
 730: 'plow, plough',
 731: "plunger, plumber's helper",
 732: 'Polaroid camera, Polaroid Land camera',
 733: 'pole',
 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
 735: 'poncho',
 736: 'pool table, billiard table, snooker table',
 737: 'pop bottle, soda bottle',
 738: 'pot, flowerpot',
 739: "potter's wheel",
 740: 'power drill',
 741: 'prayer rug, prayer mat',
 742: 'printer',
 743: 'prison, prison house',
 744: 'projectile, missile',
 745: 'projector',
 746: 'puck, hockey puck',
 747: 'punching bag, punch bag, punching ball, punchball',
 748: 'purse',
 749: 'quill, quill pen',
 750: 'quilt, comforter, comfort, puff',
 751: 'racer, race car, racing car',
 752: 'racket, racquet',
 753: 'radiator',
 754: 'radio, wireless',
 755: 'radio telescope, radio reflector',
 756: 'rain barrel',
 757: 'recreational vehicle, RV, R.V.',
 758: 'reel',
 759: 'reflex camera',
 760: 'refrigerator, icebox',
 761: 'remote control, remote',
 762: 'restaurant, eating house, eating place, eatery',
 763: 'revolver, six-gun, six-shooter',
 764: 'rifle',
 765: 'rocking chair, rocker',
 766: 'rotisserie',
 767: 'rubber eraser, rubber, pencil eraser',
 768: 'rugby ball',
 769: 'rule, ruler',
 770: 'running shoe',
 771: 'safe',
 772: 'safety pin',
 773: 'saltshaker, salt shaker',
 774: 'sandal',
 775: 'sarong',
 776: 'sax, saxophone',
 777: 'scabbard',
 778: 'scale, weighing machine',
 779: 'school bus',
 780: 'schooner',
 781: 'scoreboard',
 782: 'screen, CRT screen',
 783: 'screw',
 784: 'screwdriver',
 785: 'seat belt, seatbelt',
 786: 'sewing machine',
 787: 'shield, buckler',
 788: 'shoe shop, shoe-shop, shoe store',
 789: 'shoji',
 790: 'shopping basket',
 791: 'shopping cart',
 792: 'shovel',
 793: 'shower cap',
 794: 'shower curtain',
 795: 'ski',
 796: 'ski mask',
 797: 'sleeping bag',
 798: 'slide rule, slipstick',
 799: 'sliding door',
 800: 'slot, one-armed bandit',
 801: 'snorkel',
 802: 'snowmobile',
 803: 'snowplow, snowplough',
 804: 'soap dispenser',
 805: 'soccer ball',
 806: 'sock',
 807: 'solar dish, solar collector, solar furnace',
 808: 'sombrero',
 809: 'soup bowl',
 810: 'space bar',
 811: 'space heater',
 812: 'space shuttle',
 813: 'spatula',
 814: 'speedboat',
 815: "spider web, spider's web",
 816: 'spindle',
 817: 'sports car, sport car',
 818: 'spotlight, spot',
 819: 'stage',
 820: 'steam locomotive',
 821: 'steel arch bridge',
 822: 'steel drum',
 823: 'stethoscope',
 824: 'stole',
 825: 'stone wall',
 826: 'stopwatch, stop watch',
 827: 'stove',
 828: 'strainer',
 829: 'streetcar, tram, tramcar, trolley, trolley car',
 830: 'stretcher',
 831: 'studio couch, day bed',
 832: 'stupa, tope',
 833: 'submarine, pigboat, sub, U-boat',
 834: 'suit, suit of clothes',
 835: 'sundial',
 836: 'sunglass',
 837: 'sunglasses, dark glasses, shades',
 838: 'sunscreen, sunblock, sun blocker',
 839: 'suspension bridge',
 840: 'swab, swob, mop',
 841: 'sweatshirt',
 842: 'swimming trunks, bathing trunks',
 843: 'swing',
 844: 'switch, electric switch, electrical switch',
 845: 'syringe',
 846: 'table lamp',
 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
 848: 'tape player',
 849: 'teapot',
 850: 'teddy, teddy bear',
 851: 'television, television system',
 852: 'tennis ball',
 853: 'thatch, thatched roof',
 854: 'theater curtain, theatre curtain',
 855: 'thimble',
 856: 'thresher, thrasher, threshing machine',
 857: 'throne',
 858: 'tile roof',
 859: 'toaster',
 860: 'tobacco shop, tobacconist shop, tobacconist',
 861: 'toilet seat',
 862: 'torch',
 863: 'totem pole',
 864: 'tow truck, tow car, wrecker',
 865: 'toyshop',
 866: 'tractor',
 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
 868: 'tray',
 869: 'trench coat',
 870: 'tricycle, trike, velocipede',
 871: 'trimaran',
 872: 'tripod',
 873: 'triumphal arch',
 874: 'trolleybus, trolley coach, trackless trolley',
 875: 'trombone',
 876: 'tub, vat',
 877: 'turnstile',
 878: 'typewriter keyboard',
 879: 'umbrella',
 880: 'unicycle, monocycle',
 881: 'upright, upright piano',
 882: 'vacuum, vacuum cleaner',
 883: 'vase',
 884: 'vault',
 885: 'velvet',
 886: 'vending machine',
 887: 'vestment',
 888: 'viaduct',
 889: 'violin, fiddle',
 890: 'volleyball',
 891: 'waffle iron',
 892: 'wall clock',
 893: 'wallet, billfold, notecase, pocketbook',
 894: 'wardrobe, closet, press',
 895: 'warplane, military plane',
 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
 897: 'washer, automatic washer, washing machine',
 898: 'water bottle',
 899: 'water jug',
 900: 'water tower',
 901: 'whiskey jug',
 902: 'whistle',
 903: 'wig',
 904: 'window screen',
 905: 'window shade',
 906: 'Windsor tie',
 907: 'wine bottle',
 908: 'wing',
 909: 'wok',
 910: 'wooden spoon',
 911: 'wool, woolen, woollen',
 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
 913: 'wreck',
 914: 'yawl',
 915: 'yurt',
 916: 'web site, website, internet site, site',
 917: 'comic book',
 918: 'crossword puzzle, crossword',
 919: 'street sign',
 920: 'traffic light, traffic signal, stoplight',
 921: 'book jacket, dust cover, dust jacket, dust wrapper',
 922: 'menu',
 923: 'plate',
 924: 'guacamole',
 925: 'consomme',
 926: 'hot pot, hotpot',
 927: 'trifle',
 928: 'ice cream, icecream',
 929: 'ice lolly, lolly, lollipop, popsicle',
 930: 'French loaf',
 931: 'bagel, beigel',
 932: 'pretzel',
 933: 'cheeseburger',
 934: 'hotdog, hot dog, red hot',
 935: 'mashed potato',
 936: 'head cabbage',
 937: 'broccoli',
 938: 'cauliflower',
 939: 'zucchini, courgette',
 940: 'spaghetti squash',
 941: 'acorn squash',
 942: 'butternut squash',
 943: 'cucumber, cuke',
 944: 'artichoke, globe artichoke',
 945: 'bell pepper',
 946: 'cardoon',
 947: 'mushroom',
 948: 'Granny Smith',
 949: 'strawberry',
 950: 'orange',
 951: 'lemon',
 952: 'fig',
 953: 'pineapple, ananas',
 954: 'banana',
 955: 'jackfruit, jak, jack',
 956: 'custard apple',
 957: 'pomegranate',
 958: 'hay',
 959: 'carbonara',
 960: 'chocolate sauce, chocolate syrup',
 961: 'dough',
 962: 'meat loaf, meatloaf',
 963: 'pizza, pizza pie',
 964: 'potpie',
 965: 'burrito',
 966: 'red wine',
 967: 'espresso',
 968: 'cup',
 969: 'eggnog',
 970: 'alp',
 971: 'bubble',
 972: 'cliff, drop, drop-off',
 973: 'coral reef',
 974: 'geyser',
 975: 'lakeside, lakeshore',
 976: 'promontory, headland, head, foreland',
 977: 'sandbar, sand bar',
 978: 'seashore, coast, seacoast, sea-coast',
 979: 'valley, vale',
 980: 'volcano',
 981: 'ballplayer, baseball player',
 982: 'groom, bridegroom',
 983: 'scuba diver',
 984: 'rapeseed',
 985: 'daisy',
 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
 987: 'corn',
 988: 'acorn',
 989: 'hip, rose hip, rosehip',
 990: 'buckeye, horse chestnut, conker',
 991: 'coral fungus',
 992: 'agaric',
 993: 'gyromitra',
 994: 'stinkhorn, carrion fungus',
 995: 'earthstar',
 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
 997: 'bolete',
 998: 'ear, spike, capitulum',
 999: 'toilet tissue, toilet paper, bathroom tissue'}
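The dictionary above maps ImageNet class indices (0 to 999) to WordNet-style names, with synonyms separated by ", ". The variable name of the full dictionary is not visible in this section, so the sketch below uses `IMAGENET_CLASSES_EXCERPT`, a hypothetical five-entry excerpt. It shows how a model's per-class scores could be turned into readable top-k labels. The scores are made-up inputs, and `topk_names` is a hypothetical helper, not part of the author's repository.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical five-entry excerpt of the index -> name dictionary above
# (the real dictionary has 1000 entries, 0 to 999).
IMAGENET_CLASSES_EXCERPT: Dict[int, str] = {
    207: 'golden retriever',
    208: 'Labrador retriever',
    281: 'tabby, tabby cat',
    285: 'Egyptian cat',
    388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
}


def topk_names(scores: Sequence[float], names: Dict[int, str], k: int = 5) -> List[Tuple[int, str, float]]:
    """Return (class index, short name, score) for the k highest scores."""
    # Sort class indices by score, highest first, and keep the first k.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Each name lists synonyms separated by ', '; keep only the first one as a label.
    return [(i, names.get(i, f'class {i}').split(', ')[0], scores[i]) for i in order]


# Made-up scores for a single image over 1000 classes.
scores = [0.0] * 1000
scores[207], scores[208], scores[281] = 0.70, 0.20, 0.05
print(topk_names(scores, IMAGENET_CLASSES_EXCERPT, k=3))
# [(207, 'golden retriever', 0.7), (208, 'Labrador retriever', 0.2), (281, 'tabby', 0.05)]
```

With a real PyTorch output tensor, `torch.topk(logits, 5)` returns the top values and indices directly. The indices can then be looked up in the dictionary the same way.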

Results

1.LeNet5

basenet: lenet5 (image size: 32 * 32 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562 
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30
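The configuration uses `scheduler: MultiStepLR` with `milestones: [15, 20, 30]` and `gamma: 0.1`, so the learning rate is multiplied by 0.1 each time the epoch counter reaches a milestone. PyTorch's `MultiStepLR` has the closed form `base_lr * gamma ** bisect_right(milestones, epoch)`. The pure-Python sketch below reproduces that formula so the schedule can be checked without a GPU, assuming `scheduler.step()` is called once per epoch. One consequence: with 30 epochs numbered 0 to 29, the milestone at 30 is never reached during training.

```python
from bisect import bisect_right
from typing import List


def multistep_lr(base_lr: float, milestones: List[int], gamma: float, epoch: int) -> float:
    """Learning rate in effect during `epoch` (0-indexed), assuming one scheduler.step() per epoch."""
    # Count how many milestones are <= epoch and decay once per milestone passed.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Values taken from the LeNet5 configuration above.
base_lr, milestones, gamma = 0.01, [15, 20, 30], 0.1
for epoch in (0, 14, 15, 19, 20, 29):
    print(epoch, f'{multistep_lr(base_lr, milestones, gamma, epoch):.0e}')
# 0 1e-02 / 14 1e-02 / 15 1e-03 / 19 1e-03 / 20 1e-04 / 29 1e-04
```

In the real training script the corresponding objects would presumably be `torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)` and `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 20, 30], gamma=0.1)`. This section does not show whether the author builds them exactly this way.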
epoch | time      | top1 acc (%) | top5 acc (%)
0     | 0h0min23s | 50.00        | 93.75
1     | 0h0min21s | 62.50        | 96.88
2     | 0h0min24s | 65.62        | 96.88
3     | 0h0min21s | 53.12        | 96.88
...   | ...       | ...          | ...
29    | 0h0min23s | 75.00        | 100.00
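The top1/top5 columns count a prediction as correct when the true label is among the 1 or 5 highest-scoring classes. The author's evaluation code is not shown here, so the following is a minimal pure-Python sketch of that metric, and `topk_accuracy` is a hypothetical name. Incidentally, every per-epoch value in these tables is a multiple of 3.125% (1/32). That would be consistent with accuracy measured on a single batch of `batch_size: 32` images rather than on a full test set, but this is an inference from the numbers, not something the post states.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top
    return 100.0 * hits / len(labels)


# Toy batch: 4 samples, 6 classes (made-up numbers).
scores = [
    [0.1, 0.5, 0.2, 0.0, 0.1, 0.1],
    [0.3, 0.1, 0.1, 0.4, 0.05, 0.05],
    [0.25, 0.2, 0.3, 0.12, 0.03, 0.1],
    [0.0, 0.1, 0.1, 0.1, 0.6, 0.1],
]
labels = [1, 0, 4, 4]
print(topk_accuracy(scores, labels, 1), topk_accuracy(scores, labels, 5))  # 50.0 75.0
```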

Total

epochs | time     | avg top1 acc (%)   | avg top5 acc (%)
30     | 0h11m44s | 62.208533333333335 | 95.97033333333
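The time strings come in two slightly different formats: `0h0min23s` per epoch and `0h11m44s` in the totals. A small parser makes it easy to sanity-check the totals reported in this section. For example, 0h11m44s is 704 s, or about 23.5 s per epoch over 30 epochs, which matches the 21 to 24 s LeNet5 rows. `to_seconds` is a hypothetical helper, not part of the author's repository.

```python
import re

# Accepts both 'XhYminZs' (per-epoch rows) and 'XhYmZs' (totals).
_TIME_RE = re.compile(r'^(\d+)h(\d+)m(?:in)?(\d+)s$')


def to_seconds(text: str) -> int:
    """Parse '0h0min23s' or '1h23m43s' into a number of seconds."""
    match = _TIME_RE.match(text.strip())
    if match is None:
        raise ValueError(f'unrecognized time string: {text!r}')
    hours, minutes, seconds = (int(g) for g in match.groups())
    return hours * 3600 + minutes * 60 + seconds


# Totals copied from the result tables in this section (30 epochs each).
totals = {'LeNet5': '0h11m44s', 'AlexNet': '0h22m44s', 'VGG16': '1h23m43s'}
for name, total in totals.items():
    secs = to_seconds(total)
    print(f'{name}: {secs} s total, {secs / 30:.1f} s per epoch')
# LeNet5: 704 s total, 23.5 s per epoch
# AlexNet: 1364 s total, 45.5 s per epoch
# VGG16: 5023 s total, 167.4 s per epoch
```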

2.AlexNet

basenet: alexnet (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562 
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30 
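AlexNet and VGG16 are trained on CIFAR-10 with `image size: 224 * 224 * 3`, so the 32 x 32 CIFAR images are presumably resized to 224 x 224 first. The transform itself is not shown in this section. The standard output-size formula for a convolution or pooling layer, floor((n + 2p - k) / s) + 1, shows why the input size matters. The sketch assumes torchvision-style AlexNet settings: an 11 x 11 stride-4 first convolution with padding 2, and three 3 x 3 stride-2 max-pools, as in `torchvision.models.alexnet`. Under those settings, a 224 input leaves a 6 x 6 map, whereas a raw 32 input makes the last pooling layer produce an invalid 0 x 0 map.

```python
from typing import List, Tuple

# (name, kernel, stride, padding) for each spatial layer, assumed to follow torchvision's AlexNet.
ALEXNET_LAYERS: List[Tuple[str, int, int, int]] = [
    ('conv1', 11, 4, 2),
    ('pool1', 3, 2, 0),
    ('conv2', 5, 1, 2),
    ('pool2', 3, 2, 0),
    ('conv3', 3, 1, 1),
    ('conv4', 3, 1, 1),
    ('conv5', 3, 1, 1),
    ('pool3', 3, 2, 0),
]


def out_size(n: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


def trace(n: int) -> List[Tuple[str, int]]:
    """Feature-map side length after each layer for an n x n input."""
    sizes = []
    for name, k, s, p in ALEXNET_LAYERS:
        n = out_size(n, k, s, p)
        sizes.append((name, n))
    return sizes


print(trace(224))  # conv1 gives 55, ..., pool3 gives 6
print(trace(32))   # conv1 gives 7, ..., pool3 gives 0 (invalid)
```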
epoch | time      | top1 acc (%) | top5 acc (%)
0     | 0h0min45s | 50.00        | 90.62
1     | 0h0min44s | 62.50        | 93.75
2     | 0h0min46s | 68.75        | 96.88
3     | 0h0min44s | 62.50        | 100.00
...   | ...       | ...          | ...
29    | 0h0min42s | 100.00       | 100.00

Total

epochs | time     | avg top1 acc (%)  | avg top5 acc (%)
30     | 0h22m44s | 86.27453333333334 | 98.99946666666666

3.VGG

basenet: vgg16 (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562 
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30 
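VGG16 keeps the spatial size constant inside each block, using 3 x 3 convolutions with padding 1. After each of its five blocks, a 2 x 2 stride-2 max-pool halves the size. The sketch assumes the author's `vgg16` follows the original configuration D, as `torchvision.models.vgg16` does. Under that assumption, a 224 x 224 input reaches the classifier as a 7 x 7 x 512 tensor, i.e. 25088 features. The 13 convolution layers plus 3 fully connected layers account for the "16" in the name.

```python
from typing import List, Tuple, Union

# VGG16 (configuration D): ints are conv output channels, 'M' is a 2x2 stride-2 max-pool.
VGG16_CFG: List[Union[int, str]] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                                    512, 512, 512, 'M', 512, 512, 512, 'M']


def vgg_feature_shape(size: int, cfg: List[Union[int, str]]) -> Tuple[int, int, int]:
    """Return (channels, height, width) after the conv/pool stack for a square input."""
    channels = 3
    for layer in cfg:
        if isinstance(layer, int):
            channels = layer  # 3x3 conv with padding 1 keeps the spatial size
        else:
            size //= 2        # 2x2 max-pool with stride 2 halves it
    return channels, size, size


c, h, w = vgg_feature_shape(224, VGG16_CFG)
num_convs = sum(1 for layer in VGG16_CFG if isinstance(layer, int))
print(c, h, w, c * h * w, num_convs)  # 512 7 7 25088 13
```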
epoch | time      | top1 acc (%) | top5 acc (%)
0     | 0h2min46s | 25.00        | 71.88
1     | 0h2min45s | 53.12        | 87.50
2     | 0h2min44s | 40.62        | 96.88
3     | 0h2min42s | 34.38        | 90.62
...   | ...       | ...          | ...
29    | 0h2min44s | 100.00       | 100.00

Total

epochs | time     | avg top1 acc (%)  | avg top5 acc (%)
30     | 1h23m43s | 76.55606666666667 | 96.441

4.ResNet

basenet: resnet18
dataset: ImageNet
image size: 224 * 224 * 3 (customizable)
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.001
epoch: 30
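The CIFAR runs report `len(dataset): 50000, iter_size: 1562`, which equals 50000 // 32, the number of full batches of 32. The remaining 16 images would form a partial batch, and the code that decides whether it is used is not shown here. The ResNet run does not list these values. Assuming it uses the standard ILSVRC2012 training split of 1,281,167 images, the same arithmetic gives 40,036 iterations per epoch, roughly 25.6 times as many as a CIFAR epoch. `iters_per_epoch` is a hypothetical helper mirroring the `drop_last` semantics of PyTorch's `DataLoader`.

```python
def iters_per_epoch(num_samples: int, batch_size: int, drop_last: bool = True) -> int:
    """Number of batches per epoch; drop_last=True discards a final partial batch."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1


cifar_train = 50_000          # len(dataset) reported for the CIFAR runs
imagenet_train = 1_281_167    # ILSVRC2012 training images (assumed split for the ResNet run)
print(iters_per_epoch(cifar_train, 32))                   # 1562, as in the logs
print(iters_per_epoch(cifar_train, 32, drop_last=False))  # 1563
print(iters_per_epoch(imagenet_train, 32))                # 40036
```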
epoch | time       | top1 acc (%) | top5 acc (%)
0     | 4h22min38s | 28.12        | 43.75
1     | 3h59min35s | 34.38        | 59.38
2     | 3h58min0s  | 65.62        | 84.38
3     | 3h48min56s | 46.88        | 75.00
4     | 3h54min36s | 53.12        | 75.00
5     | 3h49min35s | 56.25        | 71.88
...   | ...        | ...          | ...

To be continued...

Added: 2022-06-29 19:04:19  Updated: 2022-06-29 19:04:42