IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 利用预训练Inception-v3模型权重进行特征提取(图像识别) -> 正文阅读

[Python知识库]利用预训练Inception-v3模型权重进行特征提取(图像识别)

一、Tensorflow+Keras
参考文章:

https://blog.csdn.net/baidu_36669549/article/details/84993772

在tensorflow官网的图像识别的中文介绍中,介绍了如何用Tensorflow的模型代码库中的classify_image.py进行图像识别。里面有介绍如何测试,而且还提供了最后一层的112048维的特征提取方式,所以在这里介绍一下。

......
 
with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
 
.......

源码就是在这里进行的介绍,有三个接口,

‘softmax:0’: A tensor containing the normalized prediction across 1000 labels.

‘pool_3:0’: A tensor containing the next-to-last layer containing 2048 float description of the image.

‘DecodeJpeg/contents:0’: A tensor containing a string providing JPEG encoding of the image.

预测的话直接调’softmax:0’:和’DecodeJpeg/contents:0’:可以进行图像识别的测试

如果想要提取特征就像这样

            fc_tensor = sess.graph.get_tensor_by_name('pool_3:0')
            pool_1 = sess.run(fc_tensor,{'DecodeJpeg/contents:0': image_data})

就可以,保存的话可以选择CSV或者.mat文件

import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import scipy.io as scio
 
model_dir='F:/fqh/models-master/tutorials/image/imagenet/2015'
image = 'F:/fqh/models-master/tutorials/image/imagenet/data_set/face/faces96_152_20_180-200jpgview-depth/'
 
target_path=image+'wjhugh/'
class NodeLookup(object):
    def __init__(self, label_lookup_path=None, uid_lookup_path=None):
        if not label_lookup_path:
            label_lookup_path = os.path.join(
                    model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
            uid_lookup_path = os.path.join(
                    model_dir, 'imagenet_synset_to_human_label_map.txt')
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
 
    def load(self, label_lookup_path, uid_lookup_path):
        if not tf.gfile.Exists(uid_lookup_path):
            tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
            tf.logging.fatal('File does not exist %s', label_lookup_path)
 
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        
        uid_to_human = {}
        for line in proto_as_ascii_lines:
 
            line = line.strip('\n')
 
            parse_items = line.split('\t')
 
            uid = parse_items[0]
 
            human_string = parse_items[1]
 
            uid_to_human[uid] = human_string
            
 
 
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
 
        node_id_to_uid = {}
        for line in proto_as_ascii:
 
            if line.startswith('  target_class:'):
 
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
 
                target_class_string = line.split(': ')[1]
 
                node_id_to_uid[target_class] = target_class_string[1:-2]
    
 
        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
 
            if val not in uid_to_human:
                tf.logging.fatal('Failed to locate: %s', val)
 
            name = uid_to_human[val]
 
            node_id_to_name[key] = name
    
        return node_id_to_name
 
 
    def id_to_string(self, node_id):
 
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]
 
 
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
 
create_graph()
list0=[]
for root, dirs,files in os.walk(image):  
    list0.append(dirs)
#print(list0[0])
img_list=[]
# print(img_list)
for ii in list0[0]:
    img_list.append(ii)
list_img_name=np.array(img_list)
list_img_name.sort() 
# print(list_img_name[0])
 
with tf.Session() as sess:
 
     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
 
     for jj in range(0,len(list_img_name)):#len(list_img_name)
 
         target_path=image+list_img_name[jj]+'/'
         for root, dirs, files in os.walk(target_path):
             for file in files:
                 img_path = target_path+file
                 image_data = tf.gfile.FastGFile(img_path, 'rb').read()
                 fc_tensor = sess.graph.get_tensor_by_name('pool_3:0')
                 pool_1 = sess.run(fc_tensor,{'DecodeJpeg/contents:0': image_data})
                 pool_2 = pool_1[0,0,0,:]
                 img_path=img_path[:len(img_path)-4]
                 scio.savemat(img_path+'.mat', {"pool_2": pool_2})
                 pi= (jj/(len(list_img_name)-1))*100
                 print("%4.2f %%" % pi)

将向量拉直并遍历整个数据集

二、Pytorh版

import os.path
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
features_dir ='./features'# 存放特征的?件夹路径
data_list =[]
path='./FullFrame1'
for filename in os.listdir(path):
    transform1 = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
                         ])#转换成Tensor
print(filename)
img = Image.open(path +'/'+ filename)# 打开图?
img1 = transform1(img)# 对图?进?transform1的各种操作

# resnet18 = models.resnet18(pretrained = True)
inception_v3_feature_extractor = models.inception_v3(pretrained=True)# 导?inception_v3的预训练模型
inception_v3_feature_extractor.fc = nn.Linear(2048,2048)# 重新定义最后?层
torch.nn.init.eye(inception_v3_feature_extractor.fc.weight)#将?维tensor初始化为单位矩阵
for param in inception_v3_feature_extractor.parameters():
        param.requires_grad =False
# resnet152 = models.resnet152(pretrained = True)
# densenet201 = models.densenet201(pretrained = True)
print(img1.shape)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
# x = Variable(img1, requires_grad=False)
print(x.shape)
inception_v3_feature_extractor.training=False
y = inception_v3_feature_extractor(x)#报错原因在于这
    
y = y.data.numpy()
data_list.append(y)
data_npy = np.array(data_list)
print(data_npy.shape)
np.save('1.npy',data_list)

修改后,每个图片的特征保存为一个.txt文件(也可以是.csv文件):
在这里插入图片描述

import os.path
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image

def feature_extract(file,save_path):
  transform1 = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
                         ])#转换成Tensor
  img = Image.open(file)# 打开图?
  img1 = transform1(img)# 对图?进?transform1的各种操作
  
  #print(img1.shape)
  x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
  #x = Variable(img1, requires_grad=False)
  #print(x.shape)
  inception_v3_feature_extractor.training=False
  y = inception_v3_feature_extractor(x) 

  y = y.data.numpy()
  #print(y.shape)
  y = y.reshape(1,2048)
  np.savetxt(save_path + '.txt', y, fmt='%s')  # 保存特征文件为txt

def read_image(rootdir,save_path):
    list = os.listdir(rootdir) #列出文件夹下所有的目录与文件
    # print(list)
    # files = []
    #for i in range(0,len(list)):
    #   path = os.path.join(rootdir,list[i])
        # print(path)
        # subFiles = []
    for file in os.listdir(rootdir):
        # subFiles.append(file)
        savePath = os.path.join(save_path,file[:-4])
        #print(file)
        filename = os.path.join(rootdir,file)
        feature_extract(filename,savePath)
        # print("successfully saved "+ file[:-4] +".csv !")    # 保存特征文件为csv
        print("successfully saved " + file[:-4] + ".txt !")   # 保存特征文件为txt

if __name__ == '__main__':
  # resnet18 = models.resnet18(pretrained = True)
  inception_v3_feature_extractor = models.inception_v3(pretrained=True)# 导?inception_v3的预训练模型
  inception_v3_feature_extractor.fc = nn.Linear(2048,2048)# 重新定义最后?层
  torch.nn.init.eye(inception_v3_feature_extractor.fc.weight)#将?维tensor初始化为单位矩阵
  for param in inception_v3_feature_extractor.parameters():
    param.requires_grad = False
  # resnet152 = models.resnet152(pretrained = True)
  # densenet201 = models.densenet201(pretrained = True)
  print("Model has been onload !")
  #inception_v3_feature_extractor.summary()
  save_path ='/content/drive/MyDrive/ImageClassify_feature_extract/inception_V3_ck_feature'# 存放特征的?件夹路径
  root_dir='/content/drive/MyDrive/ImageClassify_feature_extract/ck+/data2'   # 图片路径
  read_image(root_dir,save_path)
  print("work has been done !")

参考链接:

https://wenku.baidu.com/view/993bc45a24284b73f242336c1eb91a37f111326b.html
来源:百度文库
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-05-24 18:07:23  更:2022-05-24 18:07:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/27 15:28:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计