最近学习了SSD算法,了解了其基本的实现思路,并通过SSD模型训练自己的模型。
基本环境
torch1.2.0 Pillow8.2.0 torchvision0.4.0 CUDA版本可查看自己电脑,这里使用CUDA10.0 visual studio 2019 scipy1.2.1 numpy1.17.0 matplotlib3.1.2 opencv_python4.1.2.30 tqdm4.60.0 h5py2.10.0
安装
建议创建一个虚拟环境,本文使用到的是在Pycharm环境下 打开pytorch的官方安装方法: https://pytorch.org/get-started/previous-versions/ 但是可以先进入: https://download.pytorch.org/whl/torch_stable.html 找到自己需要下载自己需要的即可。 找到自己的下载路径,然后再命令窗口定位,再使用 pip install +下载好的whl文件即可 再安装相关依赖包需要先激活环境,进行安装。 同时安装CUDA和visual studio 2019可参考网上教程,这里不细讲。
数据集的准备
本文使用VOC格式进行训练, 训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中,文件格式为xml。图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中,格式为jpg,如下图所示。
数据集处理
整个项目的文件如下(里面包含一些个人测试的代码):
第一步需要运行voc_annotation.py,并更改其代码里面的一些参数(annotation_mode、classes_path、trainval_percent、train_percent、VOCdevkit_path都可以修改,但也可以只修改以下内容即可): 需要修改model_data文件里面的voc_classes.txt内容,例如本例中修改如下:
即可生成训练用的2007_train.txt以及2007_val.txt。
图片处理
本例统一输入进来的图片是300*300大小的3通道图片。
- 对输入进来的图片进行判断是否为RGB,如果不是则进行转RGB
- 对图像进行统一大小裁剪,为防止图片失真,在其添加上灰条。
- 对图片进行数据增强,通过翻转,随机选取等操作。
模型训练
训练文件train.py中也要修改部分参数 classes_path一定要对应自己的分类文件,以及自己权重文件的位置。经过多次epochs后,权值会生成在logs文件夹。
在训练开始前还需要更改其他py文件的内容: 在summary.py文件中: m=SSD300(7,‘vgg’).to(device)中7代表的是分类的个数,这里需要修改为2,因为只本例只分为了2类。下面(3,300,300)代表输入的是300*300大小的3通道图片。
运行train.py文件进行模型训练,若出现out of memory问题,可以减小每次训练的batch_size的大小。
模型预测
模型预测先要去修改ssd.py文件中的model_path(在自己保存权值的logs文件当中选取一个权值文件,放到model_data文件夹中,并修改下面的路径,其次classes_path也要进行对应的修改: 这里单独调用摄像头进行预测,相关代码如下所示:
import time
import cv2
import numpy as np
from PIL import Image
from ssd import SSD
#口罩识别模型
if __name__ == "__main__":
ssd = SSD()
video_path = 0
video_save_path = ""
video_fps = 25.0
# 指定测量fps的时候,图片检测的次数
test_interval = 100
capture=cv2.VideoCapture(video_path)
if video_save_path!="":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
ref, frame = capture.read()
if not ref:
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
fps = 0.0
while(True):
t1 = time.time()
# 读取某一帧
ref, frame = capture.read()
if not ref:
break
# 格式转变,BGRtoRGB
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(ssd.detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
fps = ( fps + (1./(time.time()-t1)) ) / 2
print("fps= %.2f"%(fps))
frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video",frame)
if video_save_path!="":
out.write(frame)
if cv2.waitKey(10) & 0xff==ord('q'):
break
capture.release()
cv2.destroyAllWindows()
效果图如下:
未戴口罩
戴口罩
整体来说效果还是不错的。
后续
后面我又去找了其他数据集进行训练,对其进行不同的图片处理以及模型的改进,达到的效果还不错。但是图片格式为jepg的,因此在代码当中添加了对图片类型的判断,但是若不添加代码,则需要更改文件get_map.py中:
后缀为对应的图片类型。 其次自己写了一个简易版的GUI界面,使其输出各坐标,以及害虫的分类,相关代码如下:
import time
import tkinter as tk
from PIL import Image,ImageTk
from tkinter import filedialog
import cv2
import numpy as np
from ssd import SSD
import prettytable as pt
import xml.etree.ElementTree as ET
class GuiDetect:
def __init__(self):
self.ssd = SSD()
#crop指定了是否在单张图片预测后对目标进行截取
self.crop = False
#图片检测的次数
self.test_interval = 100
#选择图片路径
def selectPath(self):
path_ = filedialog.askopenfilename()
self.path.set(path_)
#统计图片上害虫的数量
# def count_num(self):
# img = cv2.imread(self.path.get())
# #裁剪图像
# img= img[200:1125, 135:1088]
# # 图像灰度化
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# # 图像二值化
# ret, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
# #对图像进行开运算和闭运算的清洗
# kernel = np.ones((5,5),np.uint8)
# #开运算
# opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
# #闭运算
# closing=cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
# #统计轮廓
# cons, hei = cv2.findContours(closing, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
# # 轮廓过滤以及绘制
# count_area=[]
# for i in range(len(cons)):
# # 筛掉面积过小的轮廓
# area = cv2.contourArea(cons[i])
# if area < 300 :
# continue
# elif area>11000 and area<800000 :
# count_area.append('0')
# count_area.append(area)
# return len(count_area)-1
#对图片进行预测,并输出结果
def predict(self):
mo_img =self.path.get()
image_mo = Image.open(mo_img)
b=Image.open(mo_img)
#检测
r_image = self.ssd.detect_image(image_mo, crop =self.crop)
im=r_image[0]
#对图片进行下采样,以一定大小放入界面
im1=b.resize((440,500),Image.ANTIALIAS)
im11=im.resize((440,500),Image.ANTIALIAS)
image= ImageTk.PhotoImage(im1)
image1=ImageTk.PhotoImage(im11)
#配置原图
self.panel.configure(image=image)
self.panel.image=image
time.sleep(1)
# 获取检测的图片
self.panel1.configure(image=image1)
self.panel1.image=image1
label=r_image[1]
box=r_image[2]
self.create_table(label, box)
#将输出结果以表格格式输出
def create_table(self,label,box):
#创建表格
tb=pt.PrettyTable()
#写一行表头
tb.field_names=['序号','类别','置信度','左上角坐标','右下角坐标']
count=1#定义一个序号从1开始可以知道检测到的害虫个数
save_count=[]
for i in range(len(label)):
mql_label=label[i].split()
#获取目标值并将byte类型进行转化str类型
#分类标签
mo_label=mql_label[0].decode(encoding='utf-8', errors='strict')
#置信度
mo_score=mql_label[1].decode(encoding='utf-8', errors='strict')
zb=box[i]
#左上角
mo_zsj=tuple(zb[0:2])
#右下角
mo_yxj=tuple(zb[-2:])
#添加到表格当中
tb.add_row([count,mo_label,mo_score,mo_zsj,mo_yxj])
count+=1
save_count.append(count)
#获取检测图片的地址
mql_path='检测图片地址为:\n{} \n'.format(self.path.get())
self.mql_text.insert(tk.INSERT,mql_path)
# count_number=self.count_num()
# el=count_number-save_count[-2]
#插入统计的害虫个数
mql_count='\n检测出的个数:{} \n'.format(save_count[-2])
self.mql_text.insert(tk.INSERT,mql_count)
#插入表格
self.mql_text.insert(tk.INSERT,tb)
# print(tb)
# print(save_count[-2])#检测的数量
sign='\n - - - - - - - - - - - - - - - - - - - - - - - - - - - \n\n'
self.mql_text.insert(tk.INSERT,sign)
def main(self):
self.root=tk.Tk()
self.root.title('害虫检测系统')
self.root.bg='black'
width,height=self.root.maxsize()
self.root.geometry("{}x{}".format(width, height))
tk.Frame(self.root)
#设置打开图片的路径(图片暂时为jprg格式)
#Fm1-----------设置内容标题
fm1=tk.Frame(bg='black')
tk.Label(fm1,text='害虫检测系统',font=('微雅软黑',50),fg='white',bg='black').pack()
fm1.pack(side=tk.TOP,expand=tk.YES,fill='x',pady=20)
#fm2------------选择路径、检测、结果标签设置
fm2=tk.Frame(bg='black')
fm2_left=tk.Frame(fm2,bg='black')
fm2_left_top=tk.Frame(fm2_left,bg='black')
fm2_left_bottom=tk.Frame(fm2_left,bg='black')
#-------------------------------------------------------------------------------
#选择图片路径
self.path = tk.StringVar(self.root,value='')
mtext=tk.Entry(fm2_left_top,font=('微软雅黑',15),width='100',fg='#FF4081',textvariable=self.path)
preButton=tk.Button(fm2_left_top,text='选择路径',bg='#BA55D3',fg='white',font=('微软雅黑',15),width='16',command=self.selectPath)
preButton.pack(side=tk.LEFT,pady=5)
#检测按钮
tk.Button(fm2_left_top,text='检测',bg='#FF4081',fg='white',font=('微软雅黑',15),width='16',command=self.predict).pack(side=tk.RIGHT)
mtext.pack(side=tk.LEFT,fill='y',padx=10,pady=5)
fm2_left_top.pack(side=tk.TOP,fill='x')
#------------------------------------------------------------------------------
#设置图片保存位置
tLabel=tk.Label(fm2_left_bottom,text='原图',bg='#4682B4',fg='white',font=('微软雅黑',15),width='20')
tLabel.pack(side=tk.LEFT)
#检测结果图片
tLabel1=tk.Label(fm2_left_bottom,text='检测图片结果',bg='#32CD32',fg='white',font=('微软雅黑',15),width='25')
tLabel1.pack(side=tk.LEFT,padx=230)
tLabel2=tk.Label(fm2_left_bottom,text='打印结果',bg='#008B8B',fg='white',font=('微软雅黑',15),width='20')
tLabel2.pack(side=tk.RIGHT,padx=20)
fm2_left_bottom.pack(side=tk.LEFT,pady=10,fill='x')
fm2_left.pack(side=tk.LEFT,padx=60,expand=tk.YES,fill='x')
fm2.pack(side=tk.TOP,expand=tk.YES,fill="x")
#--------------------------------------------------------------------------------
# fm3,#将其分为三个部分,分别是原图,检测结果图以及print的结果图
#-----------------------------------------------------------------------
fm3=tk.Frame(bg='black')
# fm3_left=tk.Frame(fm3,width=100, height=80, relief=tk.GROOVE, borderwidth=5)
fm3_left=tk.Frame(fm3,width=100, height=80,relief=tk.GROOVE)
fm3_right=tk.Frame(fm3,width=100, height=80, relief=tk.RAISED, borderwidth=5)
fm3_right_top=tk.Frame(fm3,width=100, height=80, relief=tk.RIDGE, borderwidth=5)
#------------------------------------------------------------------------------
load = Image.open('./img/1.jpg')
im=load.resize((440,500),Image.ANTIALIAS)
initIamge= ImageTk.PhotoImage(im)
self.panel = tk.Label(fm3_left,image=initIamge)
self.panel.image = initIamge
self.panel.pack(side=tk.LEFT)
self.panel1=tk.Label(fm3_right,image=initIamge)
self.panel1.image = initIamge
self.panel1.pack(side=tk.LEFT)
## 添加文本框
self.mql_text=tk.Text(fm3_right_top,width=500,height=500)
self.mql_text.pack(side=tk.LEFT,padx=10,pady=10)
#---------------------------------------------------
fm3_left.pack(side=tk.LEFT,padx=5, pady=5)
fm3_right.pack(side=tk.LEFT, padx=5, pady=5)
fm3_right_top.pack(side=tk.LEFT, padx=5, pady=5)
fm3.pack(side=tk.LEFT,expand=tk.YES,fill='both',pady=10,padx=10)
#--------------------------------------------------------------------
#-------------------------------------------------------------------------------
self.root.mainloop()
if __name__=='__main__':
self=GuiDetect()
self.main()
效果图如下:
但在模型对小目标检测方面还是存在一点问题,正在尝试提高其精度。
建议还是要先去学习下SSD模型的基本算法思路,理解起来更加清楚、明白.
|