IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 使用Pytorch自带模型预测图片 -> 正文阅读

[人工智能]使用Pytorch自带模型预测图片

如下是博主使用内置模型vgg16对一张图片(拿的VOC2007数据集里面的图片,按照道理应该拿imagenet的图片,如下mean和std也是用的imagenet数据集上的统计)进行预测的代码

import torch
import torchvision
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import matplotlib.pyplot as plt

vgg16 = torchvision.models.vgg16(pretrained=True).cuda()

#
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(), normalize]
            )
img = Image.open("2008_002682.jpg")
print(img.size)

#对图像进行归一化
img_p = transform(img)
print(img_p.shape)

#增加一个维度
img_normalize = torch.unsqueeze(img_p,0).cuda()
print(img_normalize.shape)

vgg16.eval()

out = vgg16(img_normalize)

#最后一层是1000的一维向量,每一个表示对应类别的概率
print(out.shape)

with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]

_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
prediction = [[classes[idx], percentage[idx].item()] for idx in indices[0][:5]]
print(prediction)

score = []
label = []
for i in prediction:
    print('Prediciton-> {:<25} Accuracy-> ({:.2f}%)'.format(i[0][:], i[1]))
    score.append(i[1])
    label.append(i[0])

print(score)

#把结果show出来,一些用法和matlab很相似
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 8))
fig.sca(ax1)
ax1.imshow(img)
plt.xticks([])
plt.yticks([])

barlist = ax2.bar(range(5), [i for i in score])
barlist[0].set_color('g')
plt.sca(ax2)
plt.ylim([0, 20])

plt.xticks(range(5),
           # [idx2labels[str(i)][1] for i in pred_label_idx],
           [i for i in label],
           rotation='45')
# fig.subplots_adjust(bottom=0.2)
plt.rcParams['font.size'] = '16'  # 设置字体大小
plt.rcParams['axes.unicode_minus'] = False   # 解决坐标轴负数的负号显示问题
plt.show()

?所用的imagenet_classes.txt可从网址进行下载,测试图片样子如下:

?预测结果如下(可知结果是正确的):

?对比博主之前的博文

深度学习平台实现Demo(八) - c#调用python方式完成训练和预测_jiugeshao的专栏-CSDN博客https://blog.csdn.net/jiugeshao/article/details/112093981该博文是keras框架实现了一个预测,对比下来,大概的过程类似,方法差不多。

附:

?pytorch自带了大量内置模型,相关介绍可见如下博客

pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式_shuijinghua的博客-CSDN博客_pth和pt

?Pytorch的内置模型_博客-CSDN博客_pytorch内置模型

pytorch框架--网络方面--pytorch自带模型(增、改)_雪剑封心-CSDN博客

pytorch 如何调用cuda_将Pytorch模型从CPU转换成GPU的实现方法_扎波罗热人的博客-CSDN博客

?Pytorch 高效使用GPU的操作 - 南鹤- - 博客园
pytorch提供的网络模型(预测图片类别)_z1139269312的博客-CSDN博客

?pytorch下一些常用的操作可见如下代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author:Icecream.Shao
import torch
import torchvision
from torch import nn

print(torch.cuda.is_available()) #判断是否支持cuda
print(torch.cuda.device_count()) #当前支持cuda的硬件个数
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #选择一个gpu
print(device)
used_gpu_name = torch.cuda.get_device_name(device) #获取所选择的gpu名字
print(used_gpu_name)

#print的常用语法
n=3
print('The squre of',n,'is',n*n)
print('The squre of ' + str(n) + ' is ' + str(n*n))
print('The squre of %s is %s' % (n,n*n))
print('The squre of {1} is {0}'.format(n*n, n))
print(f'model cost:{0.3:.3f}s')

#内置模型的加载方法
#vgg16 = torchvision.models.vgg16(pretrained=True).cuda()
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
print(vgg16)

#内置数据集的获取方法
train_data = torchvision.datasets.CIFAR10("./data", train=True,transform=torchvision.transforms.ToTensor,download=True)

#增加层以及修改层参数
vgg16.classifier.add_module('my_linear', nn.Linear(1000, 10))
vgg16.classifier[7] = nn.Linear(1000,2)
print(vgg16)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-03-08 22:28:33  更:2022-03-08 22:33:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 17:44:23-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码