IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> CV(计算机视觉)领域四大类之图像分类一(VGG16)(含论文和源码) -> 正文阅读

[人工智能]CV(计算机视觉)领域四大类之图像分类一(VGG16)(含论文和源码)

今天给大家带来的是VGG16深度网络的解析,虽然VGG比较出名的还有VGG19深度网络,但是本文只对VGG16网络做一下解析,至于其他变种,稍微修改一下网络层数和参数即可实现,如果这个系列比较复杂的VGG16网络可以自己搭建出来,那么其他的变种也肯定没问题。

代码git地址:vgg/pytorch · 小四/xs-dl - 码云 - 开源中国 (gitee.com)

1 VGG网络结构

论文地址:1409.1556.pdf (arxiv.org)

VGG算法团队研究了卷积网络深度在大规模的图像识别环境下对准确性的影响。其主要贡献是使用非常小的(3×3)卷积滤波器架构对网络深度的增加进行了全面评估,这表明通过将深度推到16-19加权层可以实现对现有技术配置的显著改进。也正是因为这些改进,使得VGG算法团队在定位和分类过程中分别获得了第一名和第二名。如图可以看到VGG16包含13个卷积层和三个全连接层,VGG19包含16个卷积层和三个全连接层。接下来我们会详细的剖析整个网络的搭建过程。

2 VGG16网络参数

上图的参数是根据

原论文的数据和源代码中的参数反推出来的,大部分参数直接根据原论文的网络图即可获得,其中其他参数的计算可以参照我另个一篇博文:CV(计算机视觉)领域四大类之图像分类一(AlexNet) - 知乎 (zhihu.com),里边有详细的参数计算过程。

?

3 pytarch实现

(1)数据准备make_data.py(花分类数据集)

import os

from PIL import Image

# 下载数据集地址(手动下载解压即可)
DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# 解压数据集的路径(自己定义即可)
flower_photos = "G:\\alexnet\\flower_photos\\"
# 训练数据路径(自己定义即可)
base_url = "G:\\alexnet\\train_data\\"

for item in os.listdir(flower_photos):
    path_temp = flower_photos + item
    n = 0
    for name in os.listdir(path_temp):
        n += 1
        img = Image.open(path_temp + "\\" + name)
        # 转换通道
        img = img.convert("RGB")
        # 验证集(20%验证集,80%数据集,可自行调节)
        if n % 8 == 0:
            if not os.path.exists(base_url + "val\\" + item):
                os.makedirs(base_url + "val\\" + item)
            img.save(base_url + "val\\" + item + "\\" + name)
        else:
            if not os.path.exists(base_url + "train\\" + item):
                os.makedirs(base_url + "train\\" + item)
            img.save(base_url + "train\\" + item + "\\" + name)

(2)vgg16模型网络(model.py)

import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, num_classes=5):
        super(VGG, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

(3)训练文件(train.py)

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from model import VGG


def main():
    # 确定是否可以启动GPU训练
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # (如果没有GPU环境则可以直接选择CPU训练)
    device = torch.device("cpu")

    # 设置训练集和验证集格变换规则
    train_form = transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    val_form = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 训练集
    image_path = "G:\\alexnet\\train_data\\"
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=train_form)

    # 由于我本机只有一个显卡,所以num_workers设置为0了
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)

    # 验证集
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=val_form)
    val_num = len(validate_dataset)

    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=2, shuffle=False, num_workers=0)
    net = VGG(num_classes=5)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './vgg16-Net.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

(4)预测文件(predict.py)

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import VGG


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "1.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    
    # create model
    model = vgg(model_name="vgg16", num_classes=5).to(device)
    # load model weights
    weights_path = "./vgg16Net.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

4 VGG网络优势

(1)VGG网络采用重复堆叠的小卷积核替代大卷积核,在保证具有相同感受野的条件下,减少了网络的参数,提升了网络的深度,从而提升网络特征提取的能力。三个3*3的卷积核和一个7*7感受野相同,但是参数仅仅是后者的55%。

(2)提升的网络深度都使用ReLU激活函数:提升非线性变化的能力

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-08 20:42:06  更:2022-10-08 20:46:58 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/29 7:29:54-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码