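# 视频物体检测测试 (video object detection test)
# SSD训练商品数据 (SSD training on commodity data)
# 主程序 (main program)
# GPU版本运行 (runs on the GPU build)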
import pickle
from utils.detection_generate import Generator
from utils.ssd_utils import BBoxUtility
from nets.ssd_net import SSD300
from utils.ssd_losses import MuliboxLoss
from tensorflow.python import keras
class SSDTrain(object):
    """Fine-tune an SSD300 detector on the commodity dataset.

    Construction loads the prior boxes from disk and builds the network.
    The other methods build the data generator, load pre-trained weights and
    freeze the early layers, and compile the model.
    """

    def __init__(self, num_classes=9, input_shape=(300, 300, 3), epochs=30):
        """Set training parameters, build the box utility and the model.

        :param num_classes: number of classes passed to both ``BBoxUtility``
            and ``SSD300`` (presumably including background -- confirm
            against the dataset).
        :param input_shape: model input shape; its first two entries are
            also used as the image size for the data generator.
        :param epochs: number of training epochs (stored only; no method
            visible here reads it).
        """
        self.num_classes = num_classes
        self.batch_size = 32
        self.input_shape = input_shape
        self.epochs = epochs
        self.gt_path = "./datasets/commodity_gt.pkl"
        self.image_path = "./datasets/commodity/JPEGImages/"
        # NOTE: pickle can execute arbitrary code on load; only point this
        # at trusted, locally produced files.
        with open("./datasets/prior_boxes_ssd300.pkl", "rb") as f:
            prior = pickle.load(f)
        self.bbox_util = BBoxUtility(self.num_classes, prior)
        self.pre_trained = "./ckpt/pre_trained/weights_SSD300.hdf5"
        self.model = SSD300(self.input_shape, num_classes=self.num_classes)

    @staticmethod
    def _split_keys(keys, train_ratio=0.8):
        """Sort *keys* and split them into ``(train_keys, val_keys)`` lists.

        The first ``int(round(train_ratio * len(keys)))`` sorted keys go to
        training and the remainder to validation.
        """
        names = sorted(keys)
        number = int(round(train_ratio * len(names)))
        return names[:number], names[number:]

    def get_detefction_data(self, train_ratio=0.8):
        """Build the generator that yields training and validation batches.

        Loads the ground-truth pickle at ``self.gt_path`` (presumably a dict
        keyed by image file name -- confirm). It sorts the keys, sends the
        first ``train_ratio`` fraction to training and the rest to validation.

        :param train_ratio: fraction of keys used for training (default 0.8,
            the split the original code hard-coded).
        :return: a ``Generator`` sized to ``input_shape[:2]`` with
            ``do_crop=False``.
        """
        # NOTE: pickle can execute arbitrary code on load; trusted files only.
        with open(self.gt_path, 'rb') as f:
            gt = pickle.load(f)
        train_keys, val_keys = self._split_keys(gt.keys(), train_ratio)
        return Generator(gt, self.bbox_util, self.batch_size, self.image_path,
                         train_keys, val_keys,
                         (self.input_shape[0], self.input_shape[1]),
                         do_crop=False)

    # Correctly spelled alias; the original (misspelled) name is kept so
    # existing callers keep working.
    get_detection_data = get_detefction_data

    def init_model_param(self):
        """Load pre-trained weights and freeze the early layers for fine-tuning.

        Weights are loaded from ``self.pre_trained`` with ``by_name=True``.
        The input layer and the conv1-conv3 blocks (with their pooling
        layers) are then marked non-trainable. Call this before
        ``compile()``: Keras applies ``trainable`` changes only when the
        model is compiled.

        :return: None
        """
        self.model.load_weights(self.pre_trained, by_name=True)
        freeze = {'input_1', 'conv1_1', 'conv1_2', 'pool1',
                  'conv2_1', 'conv2_2', 'pool2',
                  'conv3_1', 'conv3_2', 'conv3_3', 'pool3'}
        for layer in self.model.layers:
            if layer.name in freeze:
                layer.trainable = False
        return None

    def compile(self):
        """Compile the model with Adam and the SSD multibox loss.

        The loss function is ``MuliboxLoss(self.num_classes).compute_loss``.
        """
        self.model.compile(optimizer=keras.optimizers.Adam(),
                           loss=MuliboxLoss(self.num_classes).compute_loss)
# Script entry point. The guard originally compared against '__name__',
# which is never true, so this code never ran.
if __name__ == '__main__':
    ssd = SSDTrain(num_classes=9)
    gen = ssd.get_detefction_data()