注:请读者不要害怕,不要看到下面感觉有这么的知识点要学,其实这些知识点并没有太大的影响,你也可以直接阅读本文的代码和最后给出的教学视频学习。但是关于YOLO-V3大概的网络结构,读者可以点击下面的链接
https://mydreamambitious.blog.csdn.net/article/details/125503943
1.之前的一篇文章Opencv中使用Tracker实现物体跟踪
https://mydreamambitious.blog.csdn.net/article/details/125499463
2.YOLO-V3实时检测实现(opencv+python实现)——改进——>更加的易懂
https://blog.csdn.net/Keep_Trying_Go/article/details/125520487
3.前置知识点(必备)
(1)前一篇的对象跟踪
https://mydreamambitious.blog.csdn.net/article/details/125471084
(2)Opencv中绘制直线,矩形,圆,椭圆,多边形(包括多边形填充),绘制文本
https://mydreamambitious.blog.csdn.net/article/details/125392536
(3)Opencv中基础的知识点
https://mydreamambitious.blog.csdn.net/article/details/125351256
(4)Opencv实现图像的基本变换
https://mydreamambitious.blog.csdn.net/article/details/125428402
4.YOLO-V3权重文件,网络配置文件以及coco.names文件下载
(1)YOLOV3权重文件下载
https://pjreddie.com/darknet/yolo/
(2)YOLOV3类别文件下载
https://github.com/pjreddie/darknet/blob/master/data/coco.names
(3)YOLO.cfg配置文件下载
https://github.com/pjreddie/darknet
5.代码实战
(1)网络文件(.cfg)和权重文件的读取
#读取YOLO-V3权重文件和网络配置文件
net=cv2.dnn.readNet(model='dnn_model/yolov3.weights',config='dnn_model/yolov3.cfg')
(2)类别文件的读取
#设置置信度阈值和非极大值抑制的阈值
Confidence_thresh=0.2
Nms_thresh=0.35
#读取coco.names文件中的类别
with open('dnn_model/coco.names','r') as fp:
classes=fp.read().splitlines()
(3)网络模型的加载
#yolo-v3检测
def detect(frame):
#获取网络模型
model=cv2.dnn_DetectionModel(net)
#设置网络的输入参数
model.setInputParams(scale=1/255,size=(416,416))
#进行预测
class_id,scores,boxes=model.detect(frame,confThreshold=Confidence_thresh,
nmsThreshold=Nms_thresh)
#返回预测的类别和坐标
return class_id,scores,boxes
(4)得到预测的结果
#进行预测
class_ids,scores,boxes=detect(frame)
(5)根据预测的结果绘制框
#检测图像中的物体
(class_id,scores,boxex)=detect(frame)
for box in boxex:
(x,y,w,h)=box
#检测到物体的中心点
cx=int(x+w/2)
cy=int(y+h/2)
center_points_cur_frame.append((cx,cy))
cv2.rectangle(img=frame,pt1=(x,y),pt2=(x+w,y+h),color=(0,255,0),thickness=2)
(6)判断是否为跟踪的同一个物体
注:读者可能一开始在读下面这段代码时,可能感觉很多,建议读者先看一下最后给出的教学视频。
#开始的时候只比较当前帧和前一帧的距离
if count<=2:
for center_cur in center_points_cur_frame:
for center_prev in center_points_prev_frame:
#计算前一帧和当前帧的物体中心距离,从而判断是否为同一个物体(为了进行物体跟踪)
distance=math.hypot(center_prev[0]-center_cur[0],center_prev[1]-center_cur[1])
#判断前一帧和后一帧中的物体之间距离是否大于给定的阈值,,如果超过阈值则不是同一个物体
if distance<10:
tracker_object[track_id]=center_cur
track_id+=1
else:
#比较字典中保存的物体和当前帧的距离
bucket=tracker_object.copy()
center_points_cur_frame_copy=center_points_cur_frame.copy()
for object_id, center_prev in bucket.items():
#使用object_exist判断当前检测的物体是否存入字典(存入字典之后则不需要进行计算)
object_exist=False
for center_cur in center_points_cur_frame_copy:
# 计算已保存在字典中的物体和当前帧的物体中心距离,从而判断是否为同一个物体(为了进行物体跟踪)
distance = math.hypot(center_prev[0] - center_cur[0], center_prev[1] - center_cur[1])
# 判断已保存在字典中的物体和后一帧中的物体之间距离是否大于给定的阈值,,如果超过阈值则不是同一个物体
if distance < 10:
tracker_object[track_id] = center_cur
object_exist=True
if center_cur in center_points_cur_frame:
center_points_cur_frame.remove(center_cur)
continue
#如果遍历完当前的所有物体,都没有找到字典中与之匹配的物体,则删除id
if not object_exist:
tracker_object.pop(object_id)
#将其中新增加的物体添加到字典中
for center_cur in center_points_cur_frame:
tracker_object[track_id]=center_cur
track_id+=1
(7)整体代码
import os
import cv2
import math
import time
#基于yolo-v3的目标检测
#yolo-v3检测
def detect(frame):
# 读取YOLO-V3权重文件和网络配置文件
net = cv2.dnn.readNet(model='dnn_model/yolov3.weights', config='dnn_model/yolov3.cfg')
# 设置置信度阈值和非极大值抑制的阈值
Confidence_thresh = 0.2
Nms_thresh = 0.35
#获取网络模型
model=cv2.dnn_DetectionModel(net)
#设置网络的输入参数
model.setInputParams(scale=1/255,size=(416,416))
#进行预测
class_id,scores,boxes=model.detect(frame,confThreshold=Confidence_thresh,
nmsThreshold=Nms_thresh)
#返回预测的类别和坐标
return class_id,scores,boxes
def Object_Track():
# 打开摄像头
cap = cv2.VideoCapture('video/los_angeles.mp4')
# 计算视频帧数
count = 0
# 保存前一帧检测的物体中心点
center_points_prev_frame = []
# 创建跟踪对象
tracker_object = {}
track_id = 0
while cap.isOpened():
startime=time.time()
ret,frame=cap.read()
count+=1
# 保存当前帧检测的物体中心点
center_points_cur_frame = []
if ret==False:
break
#按图像的比例缩放图片
height,width,channel=frame.shape
# height_=int(height*(750/width))
# frame=cv2.resize(src=frame,dsize=(750,height_))
#检测图像中的物体
(class_id,scores,boxex)=detect(frame)
for box in boxex:
(x,y,w,h)=box
#检测到物体的中心点
cx=int(x+w/2)
cy=int(y+h/2)
center_points_cur_frame.append((cx,cy))
cv2.rectangle(img=frame,pt1=(x,y),pt2=(x+w,y+h),color=(0,255,0),thickness=2)
#开始的时候只比较当前帧和前一帧的距离
if count<=2:
for center_cur in center_points_cur_frame:
for center_prev in center_points_prev_frame:
#计算前一帧和当前帧的物体中心距离,从而判断是否为同一个物体(为了进行物体跟踪)
distance=math.hypot(center_prev[0]-center_cur[0],center_prev[1]-center_cur[1])
#判断前一帧和后一帧中的物体之间距离是否大于给定的阈值,,如果超过阈值则不是同一个物体
if distance<10:
tracker_object[track_id]=center_cur
track_id+=1
else:
#比较字典中保存的物体和当前帧的距离
bucket=tracker_object.copy()
center_points_cur_frame_copy=center_points_cur_frame.copy()
for object_id, center_prev in bucket.items():
#使用object_exist判断当前检测的物体是否存入字典(存入字典之后则不需要进行计算)
object_exist=False
for center_cur in center_points_cur_frame_copy:
# 计算已保存在字典中的物体和当前帧的物体中心距离,从而判断是否为同一个物体(为了进行物体跟踪)
distance = math.hypot(center_prev[0] - center_cur[0], center_prev[1] - center_cur[1])
# 判断已保存在字典中的物体和后一帧中的物体之间距离是否大于给定的阈值,,如果超过阈值则不是同一个物体
if distance < 10:
tracker_object[track_id] = center_cur
object_exist=True
if center_cur in center_points_cur_frame:
center_points_cur_frame.remove(center_cur)
continue
#如果遍历完当前的所有物体,都没有找到字典中与之匹配的物体,则删除id
if not object_exist:
tracker_object.pop(object_id)
#将其中新增加的物体添加到字典中
for center_cur in center_points_cur_frame:
tracker_object[track_id]=center_cur
track_id+=1
for object_id,center_cur in tracker_object.items():
cv2.circle(img=frame,center=center_cur,radius=3,color=(0,0,255),thickness=-1)
cv2.putText(img=frame,text=str(object_id),org=(center_cur[0],center_cur[1]-10),fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=1.0,color=(0,0,255),thickness=1)
# print('track_object: {}'.format(tracker_object))
# print('center_prev: {}'.format(center_points_prev_frame))
# print('center_cur: {}'.format(center_points_cur_frame))
#显示FPS
endtime=time.time()
FPS=1/(endtime-startime)
cv2.putText(img=frame,text='FPS: '+str(int(FPS)),org=(20,height-40),fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=1.0,color=(0,255,0),thickness=2)
cv2.imshow('img',frame)
#深度拷贝
center_points_prev_frame=center_points_cur_frame.copy()
key=cv2.waitKey(1)
if key==27:
break
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
print('Pycharm')
Object_Track()
(8)跟踪结果
视频教程:https://b23.tv/2D2RQ9L
|