IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于SSD目标检测模型的人脸口罩识别 -> 正文阅读

[人工智能]基于SSD目标检测模型的人脸口罩识别

最近学习了SSD算法,了解了其基本的实现思路,并通过SSD模型训练自己的模型。

基本环境

torch1.2.0
Pillow
8.2.0
torchvision0.4.0
CUDA版本可查看自己电脑,这里使用CUDA
10.0
visual studio 2019
scipy1.2.1
numpy
1.17.0
matplotlib3.1.2
opencv_python
4.1.2.30
tqdm4.60.0
h5py
2.10.0

安装

建议创建一个虚拟环境,本文使用到的是在Pycharm环境下
打开pytorch的官方安装方法:
https://pytorch.org/get-started/previous-versions/
但是可以先进入:
https://download.pytorch.org/whl/torch_stable.html
找到自己需要下载自己需要的即可。
在这里插入图片描述
找到自己的下载路径,然后再命令窗口定位,再使用
pip install +下载好的whl文件即可
再安装相关依赖包需要先激活环境,进行安装。
同时安装CUDA和visual studio 2019可参考网上教程,这里不细讲。

数据集的准备

本文使用VOC格式进行训练,
训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中,文件格式为xml。图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中,格式为jpg,如下图所示。
在这里插入图片描述

数据集处理

整个项目的文件如下(里面包含一些个人测试的代码):
在这里插入图片描述

第一步需要运行voc_annotation.py,并更改其代码里面的一些参数(annotation_mode、classes_path、trainval_percent、train_percent、VOCdevkit_path都可以修改,但也可以只修改以下内容即可):
在这里插入图片描述
需要修改model_data文件里面的voc_classes.txt内容,例如本例中修改如下:
在这里插入图片描述

即可生成训练用的2007_train.txt以及2007_val.txt。

图片处理

本例统一输入进来的图片是300*300大小的3通道图片。

  1. 对输入进来的图片进行判断是否为RGB,如果不是则进行转RGB
  2. 对图像进行统一大小裁剪,为防止图片失真,在其添加上灰条。
  3. 对图片进行数据增强,通过翻转,随机选取等操作。

模型训练

训练文件train.py中也要修改部分参数
在这里插入图片描述
classes_path一定要对应自己的分类文件,以及自己权重文件的位置。经过多次epochs后,权值会生成在logs文件夹。

在训练开始前还需要更改其他py文件的内容:
在summary.py文件中:
在这里插入图片描述
m=SSD300(7,‘vgg’).to(device)中7代表的是分类的个数,这里需要修改为2,因为只本例只分为了2类。下面(3,300,300)代表输入的是300*300大小的3通道图片。

运行train.py文件进行模型训练,若出现out of memory问题,可以减小每次训练的batch_size的大小。

模型预测

模型预测先要去修改ssd.py文件中的model_path(在自己保存权值的logs文件当中选取一个权值文件,放到model_data文件夹中,并修改下面的路径,其次classes_path也要进行对应的修改:
在这里插入图片描述
这里单独调用摄像头进行预测,相关代码如下所示:

import time

import cv2
import numpy as np
from PIL import Image

from ssd import SSD


#口罩识别模型
if __name__ == "__main__":
   ssd = SSD()
   video_path      = 0
   video_save_path = ""
   video_fps       = 25.0
   # 指定测量fps的时候,图片检测的次数
   test_interval = 100
   capture=cv2.VideoCapture(video_path)
   if video_save_path!="":
       fourcc = cv2.VideoWriter_fourcc(*'XVID')
       size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
       out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

   ref, frame = capture.read()
   if not ref:
       raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

   fps = 0.0
   while(True):
       t1 = time.time()
       # 读取某一帧
       ref, frame = capture.read()
       if not ref:
           break
       # 格式转变,BGRtoRGB
       frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
       # 转变成Image
       frame = Image.fromarray(np.uint8(frame))
       # 进行检测
       frame = np.array(ssd.detect_image(frame))
       # RGBtoBGR满足opencv显示格式
       frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
       
       fps  = ( fps + (1./(time.time()-t1)) ) / 2
       print("fps= %.2f"%(fps))
       frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
       
       cv2.imshow("video",frame)
       if video_save_path!="":
           out.write(frame)

       if cv2.waitKey(10) & 0xff==ord('q'):
           break
   capture.release()
   cv2.destroyAllWindows()

效果图如下:

未戴口罩
在这里插入图片描述

戴口罩
在这里插入图片描述

整体来说效果还是不错的。

后续

后面我又去找了其他数据集进行训练,对其进行不同的图片处理以及模型的改进,达到的效果还不错。但是图片格式为jepg的,因此在代码当中添加了对图片类型的判断,但是若不添加代码,则需要更改文件get_map.py中:
在这里插入图片描述

后缀为对应的图片类型。
其次自己写了一个简易版的GUI界面,使其输出各坐标,以及害虫的分类,相关代码如下:

import time
import tkinter as tk
from PIL import Image,ImageTk
from tkinter import filedialog
import cv2
import numpy as np
from ssd import SSD
import prettytable as pt
import xml.etree.ElementTree as ET

class GuiDetect:
   def __init__(self):
       self.ssd = SSD()
       #crop指定了是否在单张图片预测后对目标进行截取
       self.crop = False
       #图片检测的次数
       self.test_interval = 100
   
   #选择图片路径
   def selectPath(self):
               path_ = filedialog.askopenfilename()
               self.path.set(path_)
   
   #统计图片上害虫的数量    
   # def count_num(self):    
   #     img = cv2.imread(self.path.get())  
   #     #裁剪图像
   #     img= img[200:1125, 135:1088]
   #     # 图像灰度化
   #     gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
   #     # 图像二值化 
   #     ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)     
   #     #对图像进行开运算和闭运算的清洗
   #     kernel = np.ones((5,5),np.uint8)
   #     #开运算
   #     opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
   #     #闭运算
   #     closing=cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
       
   #     #统计轮廓       
   #     cons, hei = cv2.findContours(closing, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
   #     # 轮廓过滤以及绘制
   #     count_area=[]
   #     for i in range(len(cons)):
   #         # 筛掉面积过小的轮廓        
   #         area = cv2.contourArea(cons[i])
   #         if area < 300 :
   #             continue
   #         elif area>11000 and area<800000 :
   #             count_area.append('0')

   #         count_area.append(area)
   #     return len(count_area)-1
    

   #对图片进行预测,并输出结果
   def predict(self):
       mo_img =self.path.get()     
       image_mo = Image.open(mo_img)
       b=Image.open(mo_img)
       #检测
       r_image = self.ssd.detect_image(image_mo, crop =self.crop)
       im=r_image[0]
       #对图片进行下采样,以一定大小放入界面
       im1=b.resize((440,500),Image.ANTIALIAS)
       im11=im.resize((440,500),Image.ANTIALIAS)
       image= ImageTk.PhotoImage(im1)  
       image1=ImageTk.PhotoImage(im11)
       #配置原图
       self.panel.configure(image=image)
       self.panel.image=image
       time.sleep(1)
       # 获取检测的图片
       self.panel1.configure(image=image1)
       self.panel1.image=image1
       label=r_image[1]
       box=r_image[2]
       self.create_table(label, box)
   
   #将输出结果以表格格式输出
   def create_table(self,label,box):
       #创建表格
       tb=pt.PrettyTable()
       #写一行表头
       tb.field_names=['序号','类别','置信度','左上角坐标','右下角坐标']
       count=1#定义一个序号从1开始可以知道检测到的害虫个数
       save_count=[]
       for i in range(len(label)):
           mql_label=label[i].split()
           #获取目标值并将byte类型进行转化str类型
           #分类标签
           mo_label=mql_label[0].decode(encoding='utf-8', errors='strict')
           #置信度
           mo_score=mql_label[1].decode(encoding='utf-8', errors='strict')
           zb=box[i]
           #左上角
           mo_zsj=tuple(zb[0:2])
           #右下角
           mo_yxj=tuple(zb[-2:])
           #添加到表格当中
           tb.add_row([count,mo_label,mo_score,mo_zsj,mo_yxj])
           count+=1
           save_count.append(count)
           
       #获取检测图片的地址
       mql_path='检测图片地址为:\n{} \n'.format(self.path.get())
       self.mql_text.insert(tk.INSERT,mql_path)
       # count_number=self.count_num()
       # el=count_number-save_count[-2]
       #插入统计的害虫个数
       mql_count='\n检测出的个数:{} \n'.format(save_count[-2])
       self.mql_text.insert(tk.INSERT,mql_count)
       #插入表格
       self.mql_text.insert(tk.INSERT,tb)
       # print(tb)
       # print(save_count[-2])#检测的数量
       sign='\n - - - - - - - - - - - - - - - - - - - - - - - - - - - \n\n'
       self.mql_text.insert(tk.INSERT,sign)
       
       
   def main(self):
       self.root=tk.Tk()
       self.root.title('害虫检测系统')
       self.root.bg='black'
       width,height=self.root.maxsize()
       self.root.geometry("{}x{}".format(width, height))
       tk.Frame(self.root)   
       #设置打开图片的路径(图片暂时为jprg格式)
       
       #Fm1-----------设置内容标题
       fm1=tk.Frame(bg='black')
       tk.Label(fm1,text='害虫检测系统',font=('微雅软黑',50),fg='white',bg='black').pack()
       fm1.pack(side=tk.TOP,expand=tk.YES,fill='x',pady=20)
       
       
       #fm2------------选择路径、检测、结果标签设置
       fm2=tk.Frame(bg='black')
       fm2_left=tk.Frame(fm2,bg='black')
       fm2_left_top=tk.Frame(fm2_left,bg='black')
       fm2_left_bottom=tk.Frame(fm2_left,bg='black')
       
       #-------------------------------------------------------------------------------
       #选择图片路径        
       self.path = tk.StringVar(self.root,value='')
       mtext=tk.Entry(fm2_left_top,font=('微软雅黑',15),width='100',fg='#FF4081',textvariable=self.path)
       preButton=tk.Button(fm2_left_top,text='选择路径',bg='#BA55D3',fg='white',font=('微软雅黑',15),width='16',command=self.selectPath)
       preButton.pack(side=tk.LEFT,pady=5)
       #检测按钮
       tk.Button(fm2_left_top,text='检测',bg='#FF4081',fg='white',font=('微软雅黑',15),width='16',command=self.predict).pack(side=tk.RIGHT)      
       mtext.pack(side=tk.LEFT,fill='y',padx=10,pady=5)
       fm2_left_top.pack(side=tk.TOP,fill='x')
               
       #------------------------------------------------------------------------------
       #设置图片保存位置
       tLabel=tk.Label(fm2_left_bottom,text='原图',bg='#4682B4',fg='white',font=('微软雅黑',15),width='20')
       tLabel.pack(side=tk.LEFT)
       
       #检测结果图片
       tLabel1=tk.Label(fm2_left_bottom,text='检测图片结果',bg='#32CD32',fg='white',font=('微软雅黑',15),width='25')
       tLabel1.pack(side=tk.LEFT,padx=230)
       
       tLabel2=tk.Label(fm2_left_bottom,text='打印结果',bg='#008B8B',fg='white',font=('微软雅黑',15),width='20')
       tLabel2.pack(side=tk.RIGHT,padx=20)

       fm2_left_bottom.pack(side=tk.LEFT,pady=10,fill='x')
       fm2_left.pack(side=tk.LEFT,padx=60,expand=tk.YES,fill='x')
       fm2.pack(side=tk.TOP,expand=tk.YES,fill="x")
       
       #--------------------------------------------------------------------------------
       # fm3,#将其分为三个部分,分别是原图,检测结果图以及print的结果图
       #-----------------------------------------------------------------------
       fm3=tk.Frame(bg='black')
       # fm3_left=tk.Frame(fm3,width=100, height=80, relief=tk.GROOVE, borderwidth=5)
       fm3_left=tk.Frame(fm3,width=100, height=80,relief=tk.GROOVE)
       
       fm3_right=tk.Frame(fm3,width=100, height=80, relief=tk.RAISED, borderwidth=5)
       fm3_right_top=tk.Frame(fm3,width=100, height=80, relief=tk.RIDGE, borderwidth=5)
       #------------------------------------------------------------------------------
       
       load = Image.open('./img/1.jpg') 
       im=load.resize((440,500),Image.ANTIALIAS)
       initIamge= ImageTk.PhotoImage(im)  
       self.panel = tk.Label(fm3_left,image=initIamge)  
       self.panel.image = initIamge  
       self.panel.pack(side=tk.LEFT)
       
       
       
       
       self.panel1=tk.Label(fm3_right,image=initIamge)  
       self.panel1.image = initIamge 
       self.panel1.pack(side=tk.LEFT) 
       
       ## 添加文本框
       self.mql_text=tk.Text(fm3_right_top,width=500,height=500)
       self.mql_text.pack(side=tk.LEFT,padx=10,pady=10)
       #---------------------------------------------------       
       fm3_left.pack(side=tk.LEFT,padx=5, pady=5)
       fm3_right.pack(side=tk.LEFT, padx=5, pady=5)
       fm3_right_top.pack(side=tk.LEFT, padx=5, pady=5)
       fm3.pack(side=tk.LEFT,expand=tk.YES,fill='both',pady=10,padx=10)
       #--------------------------------------------------------------------
       #-------------------------------------------------------------------------------        
       self.root.mainloop()

if __name__=='__main__':
   self=GuiDetect()
   self.main()

效果图如下:

在这里插入图片描述
但在模型对小目标检测方面还是存在一点问题,正在尝试提高其精度。

建议还是要先去学习下SSD模型的基本算法思路,理解起来更加清楚、明白.

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-07 11:11:01  更:2022-05-07 11:14:16 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 7:32:02-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码