Notes from a research-training project, using EAST (detection) + CRNN (recognition).
2022.10.27. I drew on many articles, papers, and blog posts; if anything here infringes, please contact me and I will take it down. Two sections remain as TODO gaps (marked below).
EAST Algorithm Principles
Typical text-detection models are multi-stage: training has to be split into several stages, and a full text line is first split up, detected piecewise, and then merged back together. This hurts both accuracy and speed; the more intermediate processing a text detector does, the worse the result tends to be.

EAST (An Efficient and Accurate Scene Text Detector) removes most of these intermediate steps. Ordinary deep-learning detectors require many intermediate stages, each of which has to be tuned during training, which is time-consuming and degrades the final result. EAST instead predicts text lines directly, end to end, which is simple and elegant and improves both detection accuracy and speed.

In the pipeline comparison from the EAST paper, (a) through (d) are common text-detection pipelines built from stages such as candidate-box extraction, candidate filtering, bounding-box regression, and candidate merging, so the intermediate processing is long. Pipeline (e) is EAST: it reduces detection to just two stages, a fully convolutional network (FCN) followed by non-maximum suppression (NMS). The intermediate processing is drastically shortened, and the output supports multi-angle detection of both text lines and words, so the method is efficient, accurate, and adaptable to many natural scenes. Experiments on public benchmarks such as ICDAR 2015 showed that EAST achieved very competitive accuracy and efficiency for its time. Because EAST is structurally simple, performs well, and produces text boxes well suited to detecting text regions in road-sign scenes, it is used as the baseline detection algorithm in this work.

The EAST network has three parts: a feature-extraction stem, a feature-merging branch, and an output layer.

1. Feature-extraction stem. With PVANet (an object-detection network) as the backbone, feature maps are extracted from the convolutional stages stage1 through stage4. The spatial size of the maps halves from stage to stage while the number of kernels doubles, following the feature pyramid network (FPN) idea. This yields feature maps at several scales, so text lines of different sizes can be detected.

2. Feature-merging branch. The extracted feature maps are merged according to a U-Net-style rule:
- The feature map from the deepest stage of the stem (f1) is fed to an unpooling layer first, which upsamples it by a factor of 2.
- It is then concatenated with the feature map of the previous stage (f2), followed by a 1x1 and then a 3x3 convolution.
- The same procedure is repeated for f3 and f4, with the number of kernels shrinking stage by stage: 128, 64, 32.
- Finally, a 3x3 convolution with 32 kernels produces the merged feature map that is passed to the output layer.
3. Output layer. The final output consists of the following:
- score map: per-pixel confidence that the pixel lies in text, 1 channel;
- text boxes: per-pixel distances to the four edges of the enclosing rotated rectangle, 4 channels;
- text rotation angle: the rotation angle of that rectangle, 1 channel (the 4 distances plus the angle form the 5-channel RBOX geometry);
- text quadrangle coordinates: the corner coordinates (x1, y1), (x2, y2), (x3, y3), (x4, y4) of an arbitrary quadrilateral, 8 channels (the QUAD geometry).
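At inference time the RBOX output is decoded back into boxes. Below is a minimal NumPy sketch of decoding one pixel's prediction; the sign convention for θ (counter-clockwise positive) is an assumption of the sketch, and the repository's restore_rectangle handles the angle cases more carefully:

import numpy as np

def decode_rbox_at(x, y, dists, theta):
    # dists = (d_top, d_right, d_bottom, d_left) from pixel (x, y) to the box edges
    d_t, d_r, d_b, d_l = dists
    # box corners in the box's own frame, relative to the pixel
    offsets = np.array([[-d_l, -d_t], [d_r, -d_t], [d_r, d_b], [-d_l, d_b]])
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    return offsets @ rot.T + np.array([x, y])

# a pixel at (100, 50) inside an axis-aligned 40 x 20 box
print(decode_rbox_at(100, 50, (10, 20, 10, 20), theta=0.0))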
EAST Loss Function
The total loss is:

    L = L_s + λ_g · L_g

where L_s is the classification loss that separates text pixels from background (pixels inside a text region are labeled 1, background pixels 0), i.e., a per-pixel classification loss, and L_g is the regression loss for the rectangle formed by the text-region pixels and for its rotation angle. λ_g balances the two losses; to treat them as equally important it is set to 1. To simplify training, the classification loss is a balanced cross-entropy:

    L_s = -β · Y* · log(Ŷ) - (1 - β) · (1 - Y*) · log(1 - Ŷ)
where Ŷ is the predicted score map, Y* is the ground-truth score map, and β is a balancing factor that controls the ratio of positive to negative samples. It is computed as:

    β = 1 - (Σ_{y* ∈ Y*} y*) / |Y*|
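A small NumPy sketch of this balanced cross-entropy (illustrative; clipping is added for numerical safety, and note that the implementation later in this post actually substitutes dice loss for this term):

import numpy as np

def balanced_xent(y_pred, y_true, eps=1e-7):
    # beta down-weights the (usually dominant) negative class
    beta = 1.0 - np.mean(y_true)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_pixel = -(beta * y_true * np.log(y_pred)
                  + (1.0 - beta) * (1.0 - y_true) * np.log(1.0 - y_pred))
    return per_pixel.mean()

y_true = np.zeros((128, 128)); y_true[40:60, 30:90] = 1.0   # small text region
y_pred = np.full((128, 128), 0.1); y_pred[40:60, 30:90] = 0.8
print(balanced_xent(y_pred, y_true))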
Let L_AABB denote the box-regression loss and L_θ the rotation-angle loss:

    L_AABB = -log( IoU(R̂, R*) ) = -log( |R̂ ∩ R*| / |R̂ ∪ R*| )
    L_θ = 1 - cos(θ̂ - θ*)

where θ̂ is the predicted rotation angle and θ* is the ground-truth angle of the text rectangle, and AABB denotes the four distances from a pixel to the top, right, bottom, and left edges of its text rectangle. The geometry loss L_g is the weighted sum of the two, computed as:

    L_g = L_AABB + λ_θ · L_θ

(The paper sets λ_θ = 10; the implementation below uses 20.)
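A worked example of the geometry loss at a single pixel, mirroring the TensorFlow implementation shown later (the numbers are arbitrary):

import numpy as np

# (d_top, d_right, d_bottom, d_left) for ground truth and prediction at one pixel
d_gt = np.array([10., 20., 10., 20.])     # a 20 x 40 ground-truth box
d_pr = np.array([8., 18., 12., 22.])
w_union = min(d_gt[1], d_pr[1]) + min(d_gt[3], d_pr[3])
h_union = min(d_gt[0], d_pr[0]) + min(d_gt[2], d_pr[2])
inter = w_union * h_union
union = (d_gt[0] + d_gt[2]) * (d_gt[1] + d_gt[3]) \
      + (d_pr[0] + d_pr[2]) * (d_pr[1] + d_pr[3]) - inter
L_AABB = -np.log((inter + 1.0) / (union + 1.0))
L_theta = 1 - np.cos(0.1 - 0.0)            # predicted vs. true angle
print(L_AABB + 10 * L_theta)               # L_g with the paper's lambda_theta = 10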
EAST Algorithm Code
icdar.py: data preprocessing
# icdar.py assumes the usual scientific-Python stack; FLAGS and the helper
# functions (get_images, load_annoataion, check_and_validate_polys, shrink_poly,
# fit_line, line_cross_point, point_dist_to_line, rectangle_from_parallelogram,
# sort_rectangle) are defined elsewhere in the original file.
import os

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as Patches
from shapely.geometry import Polygon


def generator(input_size=512, batch_size=32,
              background_ratio=3./8,
              random_scale=np.array([0.5, 1, 2.0, 3.0]),
              vis=False):
    image_list = np.array(get_images())
    print('{} training images in {}'.format(
        image_list.shape[0], FLAGS.training_data_path))
    index = np.arange(0, image_list.shape[0])
    while True:
        np.random.shuffle(index)
        images = []
        image_fns = []
        score_maps = []
        geo_maps = []
        training_masks = []
        for i in index:
            try:
                im_fn = image_list[i]
                im = cv2.imread(im_fn)
                h, w, _ = im.shape
                txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt')
                if not os.path.exists(txt_fn):
                    print('text file {} does not exist'.format(txt_fn))
                    continue
                text_polys, text_tags = load_annoataion(txt_fn)
                text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
                # randomly rescale the image (and its polygons) for scale augmentation
                rd_scale = np.random.choice(random_scale)
                im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
                text_polys *= rd_scale
                # with probability background_ratio, sample a pure-background crop
                if np.random.rand() < background_ratio:
                    im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
                    if text_polys.shape[0] > 0:
                        # could not find a background-only region; skip this image
                        continue
                    # pad to a square and resize to the training input size
                    new_h, new_w, _ = im.shape
                    max_h_w_i = np.max([new_h, new_w, input_size])
                    im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                    im_padded[:new_h, :new_w, :] = im.copy()
                    im = cv2.resize(im_padded, dsize=(input_size, input_size))
                    # background crops get all-zero score and geometry maps
                    score_map = np.zeros((input_size, input_size), dtype=np.uint8)
                    geo_map_channels = 5 if FLAGS.geometry == 'RBOX' else 8
                    geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
                    training_mask = np.ones((input_size, input_size), dtype=np.uint8)
                else:
                    im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
                    if text_polys.shape[0] == 0:
                        continue
                    h, w, _ = im.shape
                    # pad the image to the training input size or to its longer side
                    new_h, new_w, _ = im.shape
                    max_h_w_i = np.max([new_h, new_w, input_size])
                    im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
                    im_padded[:new_h, :new_w, :] = im.copy()
                    im = im_padded
                    # resize to the input size and rescale the polygons accordingly
                    new_h, new_w, _ = im.shape
                    resize_h = input_size
                    resize_w = input_size
                    im = cv2.resize(im, dsize=(resize_w, resize_h))
                    resize_ratio_3_x = resize_w/float(new_w)
                    resize_ratio_3_y = resize_h/float(new_h)
                    text_polys[:, :, 0] *= resize_ratio_3_x
                    text_polys[:, :, 1] *= resize_ratio_3_y
                    new_h, new_w, _ = im.shape
                    score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)
                if vis:
                    fig, axs = plt.subplots(3, 2, figsize=(20, 30))
                    axs[0, 0].imshow(im[:, :, ::-1])
                    axs[0, 0].set_xticks([])
                    axs[0, 0].set_yticks([])
                    for poly in text_polys:
                        poly_h = min(abs(poly[3, 1] - poly[0, 1]), abs(poly[2, 1] - poly[1, 1]))
                        poly_w = min(abs(poly[1, 0] - poly[0, 0]), abs(poly[2, 0] - poly[3, 0]))
                        axs[0, 0].add_artist(Patches.Polygon(
                            poly, facecolor='none', edgecolor='green', linewidth=2, linestyle='-', fill=True))
                        axs[0, 0].text(poly[0, 0], poly[0, 1], '{:.0f}-{:.0f}'.format(poly_h, poly_w), color='purple')
                    axs[0, 1].imshow(score_map[::, ::])
                    axs[0, 1].set_xticks([])
                    axs[0, 1].set_yticks([])
                    axs[1, 0].imshow(geo_map[::, ::, 0])
                    axs[1, 0].set_xticks([])
                    axs[1, 0].set_yticks([])
                    axs[1, 1].imshow(geo_map[::, ::, 1])
                    axs[1, 1].set_xticks([])
                    axs[1, 1].set_yticks([])
                    axs[2, 0].imshow(geo_map[::, ::, 2])
                    axs[2, 0].set_xticks([])
                    axs[2, 0].set_yticks([])
                    axs[2, 1].imshow(training_mask[::, ::])
                    axs[2, 1].set_xticks([])
                    axs[2, 1].set_yticks([])
                    plt.tight_layout()
                    plt.show()
                    plt.close()
                # the maps are subsampled 4x because the network predicts at 1/4 resolution
                images.append(im[:, :, ::-1].astype(np.float32))
                image_fns.append(im_fn)
                score_maps.append(score_map[::4, ::4, np.newaxis].astype(np.float32))
                geo_maps.append(geo_map[::4, ::4, :].astype(np.float32))
                training_masks.append(training_mask[::4, ::4, np.newaxis].astype(np.float32))
                if len(images) == batch_size:
                    yield images, image_fns, score_maps, geo_maps, training_masks
                    images = []
                    image_fns = []
                    score_maps = []
                    geo_maps = []
                    training_masks = []
            except Exception as e:
                import traceback
                traceback.print_exc()
                continue
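A minimal usage sketch (this assumes FLAGS.training_data_path and the other flags are configured as in the original repository):

data_generator = generator(input_size=512, batch_size=8)
images, image_fns, score_maps, geo_maps, training_masks = next(data_generator)
# images: 8 arrays of shape (512, 512, 3), RGB, float32
# score/geo maps and masks are at 1/4 resolution: (128, 128, ...)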
def crop_area(im, polys, tags, crop_background=False, max_tries=50):
    '''
    make a random crop of the input image that does not cut through any text polygon
    :param im:
    :param polys: [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...]
    :param tags:
    :param crop_background: if True, return a crop that contains no text at all
    :param max_tries:
    :return:
    '''
    h, w, _ = im.shape
    pad_h = h//10
    pad_w = w//10
    # 1-D occupancy arrays: mark every row/column covered by some polygon
    h_array = np.zeros((h + pad_h*2), dtype=np.int32)
    w_array = np.zeros((w + pad_w*2), dtype=np.int32)
    for poly in polys:
        poly = np.round(poly, decimals=0).astype(np.int32)
        minx = np.min(poly[:, 0])
        maxx = np.max(poly[:, 0])
        w_array[minx+pad_w:maxx+pad_w] = 1
        miny = np.min(poly[:, 1])
        maxy = np.max(poly[:, 1])
        h_array[miny+pad_h:maxy+pad_h] = 1
    # candidate crop boundaries are the rows/columns not covered by any text
    h_axis = np.where(h_array == 0)[0]
    w_axis = np.where(w_array == 0)[0]
    if len(h_axis) == 0 or len(w_axis) == 0:
        return im, polys, tags
    for i in range(max_tries):
        xx = np.random.choice(w_axis, size=2)
        xmin = np.min(xx) - pad_w
        xmax = np.max(xx) - pad_w
        xmin = np.clip(xmin, 0, w-1)
        xmax = np.clip(xmax, 0, w-1)
        yy = np.random.choice(h_axis, size=2)
        ymin = np.min(yy) - pad_h
        ymax = np.max(yy) - pad_h
        ymin = np.clip(ymin, 0, h-1)
        ymax = np.clip(ymax, 0, h-1)
        # reject crops that are too small
        if xmax - xmin < FLAGS.min_crop_side_ratio*w or ymax - ymin < FLAGS.min_crop_side_ratio*h:
            continue
        if polys.shape[0] != 0:
            # keep only the polygons whose four vertices all fall inside the crop
            poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
                                & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
            selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
        else:
            selected_polys = []
        if len(selected_polys) == 0:
            # no text falls inside this crop
            if crop_background:
                return im[ymin:ymax+1, xmin:xmax+1, :], polys[selected_polys], tags[selected_polys]
            else:
                continue
        im = im[ymin:ymax+1, xmin:xmax+1, :]
        polys = polys[selected_polys]
        tags = tags[selected_polys]
        # shift polygon coordinates into the crop's coordinate frame
        polys[:, :, 0] -= xmin
        polys[:, :, 1] -= ymin
        return im, polys, tags
    return im, polys, tags
def generate_rbox(im_size, polys, tags):
    h, w = im_size
    poly_mask = np.zeros((h, w), dtype=np.uint8)
    score_map = np.zeros((h, w), dtype=np.uint8)
    geo_map = np.zeros((h, w, 5), dtype=np.float32)
    training_mask = np.ones((h, w), dtype=np.uint8)
    for poly_idx, poly_tag in enumerate(zip(polys, tags)):
        poly = poly_tag[0]
        tag = poly_tag[1]
        # shrink reference length r_i: distance to the closer adjacent vertex
        r = [None, None, None, None]
        for i in range(4):
            r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                       np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
        # score map: positive pixels are those inside the shrunk polygon
        shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        cv2.fillPoly(score_map, shrinked_poly, 1)
        cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
        # ignore text instances that are too small or tagged as unreadable ('###')
        poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
        poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
        if min(poly_h, poly_w) < FLAGS.min_text_size:
            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
        if tag:
            cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
        xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
        # fit a parallelogram to each pair of opposite edges, in both directions
        fitted_parallelograms = []
        for i in range(4):
            p0 = poly[i]
            p1 = poly[(i + 1) % 4]
            p2 = poly[(i + 2) % 4]
            p3 = poly[(i + 3) % 4]
            edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
            backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
            forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
            # edge_opposite: line parallel to edge through the farther of p2/p3
            if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
                if edge[1] == 0:
                    edge_opposite = [1, 0, -p2[0]]
                else:
                    edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
            else:
                if edge[1] == 0:
                    edge_opposite = [1, 0, -p3[0]]
                else:
                    edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
            # parallelogram extended in the forward direction
            new_p0 = p0
            new_p1 = p1
            new_p2 = p2
            new_p3 = p3
            new_p2 = line_cross_point(forward_edge, edge_opposite)
            if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
                if forward_edge[1] == 0:
                    forward_opposite = [1, 0, -p0[0]]
                else:
                    forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
            else:
                if forward_edge[1] == 0:
                    forward_opposite = [1, 0, -p3[0]]
                else:
                    forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
            new_p0 = line_cross_point(forward_opposite, edge)
            new_p3 = line_cross_point(forward_opposite, edge_opposite)
            fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
            # and the parallelogram extended in the backward direction
            new_p0 = p0
            new_p1 = p1
            new_p2 = p2
            new_p3 = p3
            new_p3 = line_cross_point(backward_edge, edge_opposite)
            if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
                if backward_edge[1] == 0:
                    backward_opposite = [1, 0, -p1[0]]
                else:
                    backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
            else:
                if backward_edge[1] == 0:
                    backward_opposite = [1, 0, -p2[0]]
                else:
                    backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
            new_p1 = line_cross_point(backward_opposite, edge)
            new_p2 = line_cross_point(backward_opposite, edge_opposite)
            fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
        # keep the smallest-area parallelogram, then turn it into a rotated rectangle
        areas = [Polygon(t).area for t in fitted_parallelograms]
        parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)
        # sort vertices so the one with the smallest x+y comes first
        parallelogram_coord_sum = np.sum(parallelogram, axis=1)
        min_coord_idx = np.argmin(parallelogram_coord_sum)
        parallelogram = parallelogram[
            [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]
        rectangle = rectangle_from_parallelogram(parallelogram)
        rectangle, rotate_angle = sort_rectangle(rectangle)
        p0_rect, p1_rect, p2_rect, p3_rect = rectangle
        # for every pixel inside the shrunk polygon, store the distances to the
        # four rectangle edges (top, right, bottom, left) and the rotation angle
        for y, x in xy_in_poly:
            point = np.array([x, y], dtype=np.float32)
            geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
            geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
            geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
            geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
            geo_map[y, x, 4] = rotate_angle
    return score_map, geo_map, training_mask
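generate_rbox relies on a few small geometry helpers (fit_line, line_cross_point, point_dist_to_line, shrink_poly, rectangle_from_parallelogram, sort_rectangle) defined elsewhere in icdar.py. For reference, minimal sketches of the three simplest ones; treat these as plausible reconstructions rather than the exact originals:

def fit_line(p_x, p_y):
    # fit the line a*x + b*y + c = 0 through two points, returned as [a, b, c]
    if p_x[0] == p_x[1]:  # vertical line x = p_x[0]
        return [1., 0., -p_x[0]]
    k, b = np.polyfit(p_x, p_y, deg=1)
    return [k, -1., b]

def line_cross_point(line1, line2):
    # intersection of two lines given in [a, b, c] form (a*x + b*y + c = 0)
    a = np.array([[line1[0], line1[1]], [line2[0], line2[1]]], dtype=np.float64)
    c = np.array([-line1[2], -line2[2]], dtype=np.float64)
    return np.linalg.solve(a, c)

def point_dist_to_line(p1, p2, p3):
    # distance from point p3 to the line through p1 and p2
    return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)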
Building the network and generating the feature maps
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim

from nets import resnet_v1  # TF-slim ResNet definition used as the backbone


def model(images, weight_decay=1e-5, is_training=True):
    # note: this implementation swaps the paper's PVANet backbone for ResNet-50
    images = mean_image_subtraction(images)
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')
    with tf.variable_scope('feature_fusion', values=[end_points.values]):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            # feature maps from four backbone stages, deepest first
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            # U-Net-style merging: upsample, concat with the shallower map,
            # then a 1x1 conv followed by a 3x3 conv
            for i in range(4):
                if i == 0:
                    h[i] = f[i]
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                if i <= 2:
                    g[i] = unpool(h[i])
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))
            # output heads: 1-channel score map, 4-channel distance map, 1-channel angle map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            # the angle is squashed into [-45, 45] degrees
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)
    return F_score, F_geometry
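The model above calls two helpers, unpool and mean_image_subtraction, that are defined elsewhere in model.py. A sketch consistent with the original repository (treat it as a reconstruction, not verbatim):

def unpool(inputs):
    # 2x bilinear upsampling, used between merging stages
    return tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*2, tf.shape(inputs)[2]*2])

def mean_image_subtraction(images, means=(123.68, 116.78, 103.94)):
    # subtract the per-channel ImageNet means from an NHWC batch
    channels = tf.split(axis=3, num_or_size_splits=3, value=images)
    for i in range(3):
        channels[i] -= means[i]
    return tf.concat(axis=3, values=channels)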
Loss function
def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss for the score map
    :param y_true_cls: ground-truth score map
    :param y_pred_cls: predicted score map
    :param training_mask: mask that ignores small / unreadable text
    :return:
    '''
    eps = 1e-5
    intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
    union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    tf.summary.scalar('classification_dice_loss', loss)
    return loss


def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, containing two parts:
    the first part uses dice loss instead of the paper's balanced cross-entropy,
    the second part is the IoU loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction of text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore text annotated as '###'
    :return:
    '''
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    # scale the classification loss to match the magnitude of the geometry loss
    classification_loss *= 0.01
    # d1 -> top, d2 -> right, d3 -> bottom, d4 -> left
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union
    area_union = area_gt + area_pred - area_intersect
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
    L_theta = 1 - tf.cos(theta_pred - theta_gt)
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls * training_mask))
    # geometry loss; this implementation weights the angle term by 20 (the paper uses 10)
    L_g = L_AABB + 20 * L_theta
    return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
Simulation Experiment Design
// TODO: to be filled in (first of the two gaps mentioned at the top)
CRNN Overview
Compared with text detection, text recognition algorithms have developed more slowly. Earlier recognition pipelines usually worked in two steps: segment the characters, then classify each character. This work uses the Convolutional Recurrent Neural Network (CRNN) proposed by Shi et al. as the baseline recognition algorithm. CRNN removes the character-segmentation step entirely: it casts text recognition as an image-based sequence-recognition problem and can directly read the text in a region image of variable length. The CRNN network consists of three parts: convolutional layers, recurrent layers, and a transcription layer.

(1) Convolutional layers. A deep convolutional network, similar in structure to VGG16, extracts feature maps from the input image. Note that the windows of the last two max-pooling layers are changed from the usual 2x2 to a rectangular 1x2 shape. Text lines are wide rectangles and most characters are taller than they are wide, so rectangular pooling keeps more resolution along the width, loses less of the information that distinguishes narrow English letters, and improves their recognition. The BatchNormalization layers speed up convergence during training. The input is a single-channel grayscale image with a fixed height of 32 and a width of 160; for a 1x32x160 input, the convolutional stack outputs a 512x1x16 feature map. This feature map cannot be fed to the recurrent layers directly; it must first be converted into a sequence of feature vectors (Map-to-Sequence): walking the feature map from left to right, the pixels of each column are stacked into one vector, and the columns in order form the sequence. For the input size above, this yields a sequence of 16 feature vectors of length 512.

(2) Recurrent layers. The recurrent network in this part makes a prediction for every element of the sequence. Let x = x_1, ..., x_t, ..., x_T denote the feature sequence, where T is its total length; the recurrent network predicts a label distribution y_t for every x_t. In this experiment the label set contains the English letters, the Arabic digits, and a special blank symbol ('-'). An RNN captures contextual information, which is more effective than classifying single characters in isolation; it can be connected to the convolutional stack and trained under one shared loss function, and it handles sequences of arbitrary length.

A vanilla RNN, however, suffers from vanishing gradients, so in practice LSTM or GRU cells are used instead. A single LSTM can only use the current and past inputs, not future ones, i.e., it learns dependencies in one direction only, while in text recognition a character is related both to the characters before it and to those after it. To exploit the full context, an LSTM that learns forward dependencies is combined with one that learns backward dependencies, giving a bidirectional LSTM. Bidirectional LSTMs can also be stacked to deepen the network, just like convolutional layers; the algorithm here stacks two of them.

The recurrent layers take the feature-vector sequence from the convolutional layers; for each feature vector, the bidirectional LSTM outputs a score for every character class, and a Softmax normalizes the scores to [0, 1]. The output is a posterior-probability matrix of 16 vectors; each vector gives the probabilities that the corresponding region of the feature map is each character, and the length of a single vector equals the total number of character classes.

(3) Transcription layer (CTC loss). The transcription layer converts the per-frame probability vectors produced by the recurrent layers into a label sequence, i.e., it finds the most probable label sequence under the posterior matrix. The raw per-frame sequence has a problem: one character may be spread over several frames, and blanks separate characters, so for example 'ab' may come out as 'aa-b'. The frames are therefore aligned: first, identical consecutive characters that are not separated by '-' are merged into one; then all '-' symbols are deleted. For example, 'ffooo-oott' maps to 'foot'.
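As an illustration, here is a minimal greedy CTC decoder implementing exactly this merge-then-delete rule; the function names and the blank-at-index-0 convention are assumptions of the sketch, not part of the original code:

import torch

def ctc_greedy_decode(log_probs, alphabet, blank=0):
    # log_probs: (T, num_classes) per-frame scores from the recurrent layers
    best = log_probs.argmax(dim=1).tolist()  # best class per frame
    decoded, prev = [], blank
    for k in best:
        # keep a symbol only if it is not blank and not a repeat of the previous frame
        if k != blank and k != prev:
            decoded.append(alphabet[k - 1])  # class i > 0 maps to alphabet[i - 1]
        prev = k
    return ''.join(decoded)

# 'aa-b' style frames: classes a=1, b=2, blank=0
frames = torch.tensor([1, 1, 0, 2])
demo = torch.nn.functional.one_hot(frames, num_classes=3).float().log_softmax(dim=1)
print(ctc_greedy_decode(demo, alphabet='ab'))  # -> 'ab'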
Loss function
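CRNN is trained with the CTC (Connectionist Temporal Classification) loss between the per-frame predictions and the unsegmented label sequence. A minimal sketch using torch.nn.CTCLoss; the shapes and variable names are illustrative, not from the original code:

import torch
import torch.nn as nn

T, batch, nclass = 16, 4, 37                               # frames, batch size, classes (blank = 0)
preds = torch.randn(T, batch, nclass, requires_grad=True)  # stand-in for CRNN output
log_probs = preds.log_softmax(2)

targets = torch.randint(1, nclass, (batch * 5,), dtype=torch.long)  # concatenated label indices
target_lengths = torch.full((batch,), 5, dtype=torch.long)
input_lengths = torch.full((batch,), T, dtype=torch.long)

criterion = nn.CTCLoss(blank=0, zero_infinity=True)
loss = criterion(log_probs, targets, input_lengths, target_lengths)
loss.backward()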
Algorithm code
import torch.nn as nn
from collections import OrderedDict


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # project the concatenated forward/backward hidden states to nOut classes
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    # VGG-style convolutional stem followed by two stacked bidirectional LSTMs

    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.relu2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.relu3_2 = nn.ReLU(True)
        # rectangular pooling: halve the height but keep the width resolution
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        x = self.pool1(self.relu1(self.conv1(input)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3_2(self.conv3_2(self.relu3_1(self.bn3(self.conv3_1(x))))))
        x = self.pool4(self.relu4_2(self.conv4_2(self.relu4_1(self.bn4(self.conv4_1(x))))))
        conv = self.relu5(self.bn5(self.conv5(x)))
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        # Map-to-Sequence: drop the height axis and make the width the time axis
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # (w, b, c)
        output = self.rnn(conv)
        return output


class CRNN_v2(nn.Module):
    # variant with a thinner double-conv stem; the stem ends at height 2,
    # which is folded into the channel dimension before the recurrent layers

    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN_v2, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        self.conv1_1 = nn.Conv2d(nc, 32, 3, 1, 1)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.relu1_1 = nn.ReLU(True)
        self.conv1_2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2_1 = nn.Conv2d(64, 64, 3, 1, 1)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.relu2_1 = nn.ReLU(True)
        self.conv2_2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3_1 = nn.Conv2d(128, 96, 3, 1, 1)
        self.bn3_1 = nn.BatchNorm2d(96)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(96, 192, 3, 1, 1)
        self.bn3_2 = nn.BatchNorm2d(192)
        self.relu3_2 = nn.ReLU(True)
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        self.conv4_1 = nn.Conv2d(192, 128, 3, 1, 1)
        self.bn4_1 = nn.BatchNorm2d(128)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn4_2 = nn.BatchNorm2d(256)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
        self.bn5 = nn.BatchNorm2d(256)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        x = self.pool1(self.relu1_2(self.bn1_2(self.conv1_2(self.relu1_1(self.bn1_1(self.conv1_1(input)))))))
        x = self.pool2(self.relu2_2(self.bn2_2(self.conv2_2(self.relu2_1(self.bn2_1(self.conv2_1(x)))))))
        x = self.pool3(self.relu3_2(self.bn3_2(self.conv3_2(self.relu3_1(self.bn3_1(self.conv3_1(x)))))))
        x = self.pool4(self.relu4_2(self.bn4_2(self.conv4_2(self.relu4_1(self.bn4_1(self.conv4_1(x)))))))
        conv = self.bn5(x)
        b, c, h, w = conv.size()
        assert h == 2, "the height of conv must be 2"
        # fold the height-2 axis into the channels: 256 * 2 = 512 features per column
        conv = conv.reshape([b, c * h, w])
        conv = conv.permute(2, 0, 1)  # (w, b, 512)
        output = self.rnn(conv)
        return output


def conv3x3(nIn, nOut, stride=1):
    return nn.Conv2d(nIn, nOut, kernel_size=3, stride=stride, padding=1, bias=False)


class basic_res_block(nn.Module):

    def __init__(self, nIn, nOut, stride=1, downsample=None):
        super(basic_res_block, self).__init__()
        m = OrderedDict()
        m['conv1'] = conv3x3(nIn, nOut, stride)
        m['bn1'] = nn.BatchNorm2d(nOut)
        m['relu1'] = nn.ReLU(inplace=True)
        m['conv2'] = conv3x3(nOut, nOut)
        m['bn2'] = nn.BatchNorm2d(nOut)
        self.group1 = nn.Sequential(m)
        self.relu = nn.Sequential(nn.ReLU(inplace=True))
        self.downsample = downsample

    def forward(self, x):
        # project the residual when the shape changes, otherwise use the identity
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
        out = self.group1(x) + residual
        out = self.relu(out)
        return out


class CRNN_res(nn.Module):
    # variant that replaces the VGG-style stem with residual blocks

    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN_res, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.res1 = basic_res_block(64, 64)
        down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
                              nn.BatchNorm2d(128))
        self.res2_1 = basic_res_block(64, 128, 2, down1)
        self.res2_2 = basic_res_block(128, 128)
        down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
                              nn.BatchNorm2d(256))
        self.res3_1 = basic_res_block(128, 256, 2, down2)
        self.res3_2 = basic_res_block(256, 256)
        self.res3_3 = basic_res_block(256, 256)
        down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=(2, 1), bias=False),
                              nn.BatchNorm2d(512))
        self.res4_1 = basic_res_block(256, 512, (2, 1), down3)
        self.res4_2 = basic_res_block(512, 512)
        self.res4_3 = basic_res_block(512, 512)
        self.pool = nn.AvgPool2d((2, 2), (2, 1), (0, 1))
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        x = self.res1(self.relu1(self.conv1(input)))
        x = self.res2_2(self.res2_1(x))
        x = self.res3_3(self.res3_2(self.res3_1(x)))
        x = self.res4_3(self.res4_2(self.res4_1(x)))
        x = self.pool(x)
        conv = self.relu5(self.bn5(self.conv5(x)))
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)
        output = self.rnn(conv)
        return output


if __name__ == '__main__':
    pass
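A quick smoke test (illustrative values; nclass = 37 would correspond to 26 letters + 10 digits + 1 CTC blank):

import torch

model = CRNN(imgH=32, nc=1, nclass=37, nh=256)
x = torch.randn(4, 1, 32, 160)  # a batch of 4 grayscale 32x160 crops
y = model(x)
print(y.shape)  # (T, batch, nclass); T is the sequence length set by the input width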
Simulation Experiment Design
// TODO: to be filled in (second of the two gaps mentioned at the top)