This post follows the Datawhale group-learning course; the original article is 动手学CV-Pytorch 6.2_使用transformer实现OCR字符识别 (using a transformer for OCR character recognition).
1. Dataset operations
import os
import cv2
base_data_dir = './ICDAR_2015'
train_img_dir = os.path.join(base_data_dir, 'train')
valid_img_dir = os.path.join(base_data_dir, 'valid')
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')
1.1 Maximum label length statistics
def statistics_max_len_label(lbl_path):
"""
统计标签文件中最长的label所包含的字符数
lbl_path:txt标签文件路径
"""
max_len = -1
with open(lbl_path, 'r',encoding = 'utf-8') as reader:
for line in reader:
items = line.rstrip().split(',')
lbl_str = items[1].strip()[1:-1]
lbl_len = len(lbl_str)
max_len = max_len if max_len>lbl_len else lbl_len
return max_len
train_max_label_len = statistics_max_len_label(train_lbl_path)
valid_max_label_len = statistics_max_len_label(valid_lbl_path)
max_label_len = max(train_max_label_len, valid_max_label_len)
print(f"数据集中包含字符最多的label长度为{max_label_len}")
数据集中包含字符最多的label长度为21
def statistics_label_cnt(lbl_path, lbl_cnt_map):
"""
统计标签文件中label都包含了哪些字符以及各自出现的次数
lbl_path:标签所处路径
lbl_cnt_map:记录标签中字符出现次数的字典
"""
with open(lbl_path, 'r',encoding = 'utf-8') as reader:
for line in reader:
items = line.rstrip().split(',')
lbl_str = items[1].strip()[1:-1]
for lbl in lbl_str:
if lbl not in lbl_cnt_map.keys():
lbl_cnt_map[lbl] = 1
else:
lbl_cnt_map[lbl] +=1
lbl_cnt_map = dict()
statistics_label_cnt(train_lbl_path, lbl_cnt_map)
print("训练集中label中出现的字符:")
print(lbl_cnt_map)
statistics_label_cnt(valid_lbl_path, lbl_cnt_map)
print("训练集+验证集label中出现的字符:")
print(lbl_cnt_map)
Characters appearing in training-set labels:
{'[': 2, '0': 182, '6': 38, ']': 2, '2': 119, '-': 68, '3': 50, 'C': 593, 'a': 843, 'r': 655, 'p': 197, 'k': 96, 'E': 1421, 'X': 110, 'I': 861, 'T': 896, 'R': 836, 'f': 133, 'u': 293, 's': 557, 'i': 651, 'o': 659, 'n': 605, 'l': 408, 'e': 1055, 'v': 123, 'A': 1189, 'U': 319, 'O': 965, 'N': 785, 'c': 318, 't': 563, 'm': 202, 'W': 179, 'H': 391, 'Y': 229, 'P': 389, 'F': 259, 'G': 345, '?': 5, 'S': 1161, 'b': 88, 'h': 299, ' ': 50, 'g': 171, 'L': 745, 'M': 367, 'D': 383, 'd': 257, '$': 46, '5': 77, '4': 44, '.': 95, 'w': 97, 'B': 331, '1': 184, '7': 43, '8': 44, 'V': 158, 'y': 161, 'K': 163, '!': 51, '9': 66, 'z': 12, ';': 3, '#': 16, 'j': 15, "'": 51, 'J': 72, ':': 19, 'x': 27, '%': 28, '/': 24, 'q': 3, 'Q': 19, '(': 6, ')': 5, '\\': 8, '"': 8, '′': 3, 'Z': 29, '&': 9, 'é': 1, '@': 4, '=': 1, '+': 1}
Characters appearing in training + validation set labels:
{'[': 2, '0': 232, '6': 44, ']': 2, '2': 139, '-': 87, '3': 69, 'C': 893, 'a': 1200, 'r': 935, 'p': 317, 'k': 137, 'E': 2213, 'X': 181, 'I': 1241, 'T': 1315, 'R': 1262, 'f': 203, 'u': 415, 's': 793, 'i': 924, 'o': 954, 'n': 880, 'l': 555, 'e': 1534, 'v': 169, 'A': 1827, 'U': 467, 'O': 1440, 'N': 1158, 'c': 442, 't': 829, 'm': 278, 'W': 288, 'H': 593, 'Y': 341, 'P': 582, 'F': 402, 'G': 521, '?': 7, 'S': 1748, 'b': 129, 'h': 417, ' ': 82, 'g': 260, 'L': 1120, 'M': 536, 'D': 548, 'd': 367, '$': 57, '5': 100, '4': 53, '.': 132, 'w': 136, 'B': 468, '1': 228, '7': 60, '8': 51, 'V': 224, 'y': 231, 'K': 253, '!': 65, '9': 76, 'z': 14, ';': 3, '#': 24, 'j': 19, "'": 70, 'J': 100, ':': 24, 'x': 38, '%': 42, '/': 29, 'q': 3, 'Q': 28, '(': 7, ')': 5, '\\': 8, '"': 8, '′': 3, 'Z': 36, '&': 15, 'é': 2, '@': 9, '=': 1, '+': 2, 'é': 1}
In the code above, lbl_cnt_map is the dictionary that counts character occurrences; it is used later to build the char-id mapping. The statistics show that the validation set contains characters that never appear in the training set, e.g. one 'é' occurs only in the validation labels. Such cases are rare and should not cause much trouble, so no extra processing is done here, but deliberately checking for this kind of train/validation mismatch is good practice.
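To make the check explicit, here is a minimal sketch that reuses the statistics_label_cnt function defined above (train_cnt_map and valid_cnt_map are throwaway names introduced here) and lists the characters that appear only in the validation labels:
train_cnt_map, valid_cnt_map = dict(), dict()
statistics_label_cnt(train_lbl_path, train_cnt_map)
statistics_label_cnt(valid_lbl_path, valid_cnt_map)
# characters that occur in the validation labels but never in the training labels
valid_only_chars = set(valid_cnt_map) - set(train_cnt_map)
print("characters only in the validation set:", valid_only_chars)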
1.2 Building the char-id mapping
print("构造label中 字符--id之间的映射:")
lbl2id_map = dict()
lbl2id_map['?'] = 0    # id 0: padding token
lbl2id_map['■'] = 1    # id 1: start-of-sequence token
lbl2id_map['□'] = 2    # id 2: end-of-sequence token
cur_id = 3
for lbl in lbl_cnt_map.keys():
lbl2id_map[lbl] = cur_id
cur_id += 1
with open(lbl2id_map_path, 'w', encoding='utf-8') as writer:
for lbl in lbl2id_map.keys():
cur_id = lbl2id_map[lbl]
print (lbl, cur_id)
line = lbl + '\t' + str(cur_id) + '\n'
writer.write(line)
Building the char-id mapping for labels:
? 0
■ 1
□ 2
[ 3
0 4
6 5
] 6
2 7
- 8
3 9
C 10
a 11
r 12
p 13
k 14
E 15
X 16
I 17
T 18
R 19
f 20
u 21
s 22
i 23
o 24
n 25
l 26
e 27
v 28
A 29
U 30
O 31
N 32
c 33
t 34
m 35
W 36
H 37
Y 38
P 39
F 40
G 41
? 42
S 43
b 44
h 45
46
g 47
L 48
M 49
D 50
d 51
$ 52
5 53
4 54
. 55
w 56
B 57
1 58
7 59
8 60
V 61
y 62
K 63
! 64
9 65
z 66
; 67
# 68
j 69
' 70
J 71
: 72
x 73
% 74
/ 75
q 76
Q 77
( 78
) 79
\ 80
" 81
′ 82
Z 83
& 84
é 85
@ 86
= 87
+ 88
é 89
def load_lbl2id_map(lbl2id_map_path):
"""
读取 字符-id 映射关系记录的txt文件,并返回 lbl->id 和 id->lbl 映射字典
lbl2id_map_path : 字符-id 映射关系记录的txt文件路径
"""
lbl2id_map = dict()
id2lbl_map = dict()
with open(lbl2id_map_path, 'r',encoding = 'utf-8') as reader:
for line in reader:
items = line.rstrip().split('\t')
label = items[0]
cur_id = int(items[1])
lbl2id_map[label] = cur_id
id2lbl_map[cur_id] = label
return lbl2id_map, id2lbl_map
1.3 Image size analysis
print("分析数据集图片尺寸:")
min_h = 1e10
min_w = 1e10
max_h = -1
max_w = -1
min_ratio = 1e10
max_ratio = 0
for img_name in os.listdir(train_img_dir):
img_path = os.path.join(train_img_dir,img_name)
img = cv2.imread(img_path)
h, w = img.shape[:2]
ratio = w / h
min_h = min_h if min_h <= h else h
max_h = max_h if max_h >= h else h
min_w = min_w if min_w <= w else w
max_w = max_w if max_w >= w else w
min_ratio = min_ratio if min_ratio <= ratio else ratio
max_ratio = max_ratio if max_ratio >= ratio else ratio
print('min_h:', min_h)
print('max_h:', max_h)
print('min_w:', min_w)
print('max_w:', max_w)
print('min_ratio:', min_ratio)
print('max_ratio:', max_ratio)
Analyzing image sizes in the dataset:
min_h: 9
max_h: 295
min_w: 16
max_w: 628
min_ratio: 0.6666666666666666
max_ratio: 8.619047619047619
2. Bringing the transformer into OCR
2.1 Preparation
import os
import time
import copy
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
from transformer import *
from train_utils import *
base_data_dir = './ICDAR_2015/'
device = torch.device('cuda')
nrof_epochs = 1500
batch_size = 16
model_save_path = './log/ex1_ocr_model.pth'
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')
lbl2id_map, id2lbl_map = load_lbl2id_map(lbl2id_map_path)
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
train_max_label_len = statistics_max_len_label(train_lbl_path)
valid_max_label_len = statistics_max_len_label(valid_lbl_path)
sequence_len = max(train_max_label_len, valid_max_label_len)
2.2 Building the Dataset
class Recognition_Dataset(object):
def __init__(self, dataset_root_dir, lbl2id_map, sequence_len, max_ratio, phase='train', pad=0):
if phase == 'train':
self.img_dir = os.path.join(base_data_dir, 'train')
self.lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
else:
self.img_dir = os.path.join(base_data_dir, 'valid')
self.lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
self.lbl2id_map = lbl2id_map
self.pad = pad
self.sequence_len = sequence_len
self.max_ratio = max_ratio * 3
self.imgs_list = []
self.lbls_list = []
with open(self.lbl_path, 'r',encoding = 'utf-8') as reader:
for line in reader:
items = line.rstrip().split(',')
img_name = items[0]
lbl_str = items[1].strip()[1:-1]
self.imgs_list.append(img_name)
self.lbls_list.append(lbl_str)
self.color_trans = transforms.ColorJitter(0.1, 0.1, 0.1)
self.trans_Normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
])
def __getitem__(self, index):
"""
获取对应index的图像和ground truth label,并视情况进行数据增强
"""
img_name = self.imgs_list[index]
img_path = os.path.join(self.img_dir, img_name)
lbl_str = self.lbls_list[index]
img = Image.open(img_path).convert('RGB')
w, h = img.size
ratio = round((w / h) * 3)
if ratio == 0:
ratio = 1
if ratio > self.max_ratio:
ratio = self.max_ratio
        # resize to a fixed height of 32 with width proportional to the (x3-scaled) aspect ratio, then pad to the maximum width
        h_new = 32
w_new = h_new * ratio
img_resize = img.resize((w_new, h_new), Image.BILINEAR)
img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
img_padd.paste(img_resize, (0, 0))
img_input = self.color_trans(img_padd)
img_input = self.trans_Normalize(img_input)
        # encoder mask: 1 for the columns covered by the real (resized) image, 0 for the padded region
        encode_mask = [1] * ratio + [0] * (self.max_ratio - ratio)
encode_mask = torch.tensor(encode_mask)
encode_mask = (encode_mask != 0).unsqueeze(0)
        # ground-truth sequence: start token (1) + char ids + end token (2) + padding (0), truncated to sequence_len
        gt = []
gt.append(1)
for lbl in lbl_str:
gt.append(self.lbl2id_map[lbl])
gt.append(2)
for i in range(len(lbl_str), self.sequence_len):
gt.append(0)
gt = gt[:self.sequence_len]
        decode_in = gt[:-1]      # decoder input: the sequence without its last token
        decode_in = torch.tensor(decode_in)
        decode_out = gt[1:]      # decoder target: the sequence shifted left by one (without the start token)
        decode_out = torch.tensor(decode_out)
decode_mask = self.make_std_mask(decode_in, self.pad)
ntokens = (decode_out != self.pad).data.sum()
return img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens
@staticmethod
def make_std_mask(tgt, pad):
"""
Create a mask to hide padding and future words.
padd 和 future words 均在mask中用0表示
"""
tgt_mask = (tgt != pad)
tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
tgt_mask = tgt_mask.squeeze(0)
return tgt_mask
def __len__(self):
return len(self.imgs_list)
That covers all the details of building the Dataset; with it in place we can construct the DataLoaders used for training (a quick sanity check on the tensor shapes follows the code below).
max_ratio = 8
train_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'train', pad=0)
valid_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'valid', pad=0)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0)
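Before moving on, a quick look at one batch helps confirm what the Dataset produces (a sketch; the exact sizes follow from max_ratio=8 and sequence_len=21 used above):
# fetch a single batch and inspect the tensor shapes
img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = next(iter(train_loader))
print(img_input.shape)    # e.g. torch.Size([16, 3, 32, 768])   -- height 32, width 32 * max_ratio * 3
print(encode_mask.shape)  # e.g. torch.Size([16, 1, 24])        -- one mask entry per 32-pixel column
print(decode_in.shape)    # e.g. torch.Size([16, 20])           -- sequence_len - 1 decoder input tokens
print(decode_out.shape)   # e.g. torch.Size([16, 20])
print(decode_mask.shape)  # e.g. torch.Size([16, 20, 20])       -- padding mask combined with the subsequent mask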
3. Model construction
The model is assembled by the make_ocr_model function and the OCR_EncoderDecoder class.
Starting with make_ocr_model: the function first takes a ResNet-18 pretrained in PyTorch as the backbone for extracting image features. You can swap in another network if you wish, but pay attention to its downsampling factor and to the channel number of the last feature map, and adjust the parameters of the related modules accordingly. It then calls the OCR_EncoderDecoder class to assemble the transformer, and finally initializes the model parameters.
The OCR_EncoderDecoder class is essentially an assembly line for the basic transformer components, including the encoder and the decoder; its constructor arguments are those pre-built components, whose code lives in transformer.py and is not repeated here.
To recap how the image becomes the Transformer's input after the backbone: the backbone outputs a feature map of shape [batch_size, 512, 1, 24]. Ignoring the batch dimension, each image yields a 1×24 feature map with 512 channels; stacking the feature values at the same spatial position across all channels forms a new vector, which is used as the input for one time step. This produces an input of shape [batch_size, 24, 512], matching the Transformer's input requirements.
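Here is a quick shape check of that construction (a minimal sketch; the weights do not matter for shapes, so the backbone is created without pretrained weights):
import torch
import torch.nn as nn
import torchvision.models as models

# ResNet-18 downsamples by 32x, so a 32x768 padded input yields a 1x24 feature map with 512 channels
backbone = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:-2])
dummy = torch.randn(1, 3, 32, 768)          # [batch_size, 3, H, W]
feat = backbone(dummy)                      # [1, 512, 1, 24]
seq = feat.squeeze(-2).permute(0, 2, 1)     # [1, 24, 512]: 24 time steps, each a 512-dim vector
print(feat.shape, seq.shape)
Now let's look at the complete model-construction code: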
class OCR_EncoderDecoder(nn.Module):
"""
A standard Encoder-Decoder architecture.
Base for this and many other models.
"""
def __init__(self, encoder, decoder, src_embed, src_position, tgt_embed, generator):
super(OCR_EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.src_position = src_position
self.tgt_embed = tgt_embed
self.generator = generator
def forward(self, src, tgt, src_mask, tgt_mask):
"Take in and process masked src and target sequences."
memory = self.encode(src, src_mask)
res = self.decode(memory, src_mask, tgt, tgt_mask)
return res
def encode(self, src, src_mask):
src_embedds = self.src_embed(src)
src_embedds = src_embedds.squeeze(-2)
src_embedds = src_embedds.permute(0, 2, 1)
src_embedds = self.src_position(src_embedds)
return self.encoder(src_embedds, src_mask)
def decode(self, memory, src_mask, tgt, tgt_mask):
target_embedds = self.tgt_embed(tgt)
return self.decoder(target_embedds, memory, src_mask, tgt_mask)
def make_ocr_model(tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
"""
构建模型
params:
tgt_vocab: 输出的词典大小(82)
N: 编码器和解码器堆叠基础模块的个数
d_model: 模型中embedding的size,默认512
d_ff: FeedForward Layer层中embedding的size,默认2048
h: MultiHeadAttention中多头的个数,必须被d_model整除
dropout:
"""
c = copy.deepcopy
backbone = models.resnet18(pretrained=True)
backbone = nn.Sequential(*list(backbone.children())[:-2])
attn = MultiHeadedAttention(h, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
position = PositionalEncoding(d_model, dropout)
model = OCR_EncoderDecoder(
Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
backbone,
c(position),
nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
Generator(d_model, tgt_vocab))
    # freeze the pretrained backbone; Xavier-initialize all other parameters with more than one dimension
    for child in model.children():
if child is backbone:
for param in child.parameters():
param.requires_grad = False
continue
for p in child.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return model
Build the transformer model:
tgt_vocab = len(lbl2id_map.keys())
d_model = 512
ocr_model = make_ocr_model(tgt_vocab, N=5, d_model=d_model, d_ff=2048, h=8, dropout=0.1)
ocr_model.to(device)
4. Model training
Before training, we still need to define the evaluation criterion, the optimizer, and so on. This experiment uses label smoothing and learning-rate warmup during training; the helper functions for both strategies live in train_utils.py, and their theory and implementation are not covered here.
Label smoothing converts the original hard labels into soft labels, which gives the model some tolerance to errors and improves generalization. LabelSmoothing() in the code implements label smoothing and internally uses the relative-entropy (KL-divergence) loss to measure the gap between predictions and targets.
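As a rough illustration with a tiny made-up vocabulary (the LabelSmoothing class in train_utils.py additionally handles the padding index, which is not shown here):
import torch

# spread a small probability mass over the non-target classes instead of using a hard one-hot target
num_classes, smoothing = 5, 0.1
target = 2                                               # ground-truth class id
soft = torch.full((num_classes,), smoothing / (num_classes - 1))
soft[target] = 1.0 - smoothing
print(soft)   # tensor([0.0250, 0.0250, 0.9000, 0.0250, 0.0250])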
The warmup strategy controls the optimizer's learning rate during training: the learning rate automatically increases from a small value and then gradually decays, which keeps training stable and helps the loss converge quickly. NoamOpt() in the code implements this warmup control on top of the Adam optimizer, adjusting the learning rate automatically as the iteration count grows.
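As a sketch, assuming NoamOpt follows the standard Noam schedule from the Annotated Transformer, the learning rate at a given step would be:
# lrate = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
def noam_rate(step, d_model=512, factor=1, warmup=400):
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

# the learning rate rises until step == warmup, then decays
for step in (1, 100, 400, 1000, 4000):
    print(step, round(noam_rate(step), 6))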
criterion = LabelSmoothing(size=tgt_vocab, padding_idx=0, smoothing=0.0)
optimizer = torch.optim.Adam(ocr_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(d_model, 1, 400, optimizer)
The SimpleLossCompute class computes the loss on the transformer's output. When it is called directly, it takes three arguments (x, y, norm): x is the decoder output, y is the label data, and norm is the normalization factor for the loss, for which the number of valid tokens in the batch works well. Only at this point is the entire transformer network fully wired up, with data able to flow through the whole computation.
class SimpleLossCompute:
"A simple loss compute and train function."
def __init__(self, generator, criterion, opt=None):
self.generator = generator
self.criterion = criterion
self.opt = opt
def __call__(self, x, y, norm):
"""
norm: loss的归一化系数,用batch中所有有效token数即可
"""
x = self.generator(x)
x_ = x.contiguous().view(-1, x.size(-1))
y_ = y.contiguous().view(-1)
loss = self.criterion(x_, y_)
loss /= norm
loss.backward()
if self.opt is not None:
self.opt.step()
self.opt.optimizer.zero_grad()
return loss.item() * norm
The training loop is shown below. Validation runs every 10 epochs, and the computation for a single epoch is wrapped in the run_epoch() function.
def run_epoch(data_loader, model, loss_compute, device=None):
"Standard Training and Logging Function"
start = time.time()
total_tokens = 0
total_loss = 0
tokens = 0
for i, batch in enumerate(data_loader):
img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
img_input = img_input.to(device)
encode_mask = encode_mask.to(device)
decode_in = decode_in.to(device)
decode_out = decode_out.to(device)
decode_mask = decode_mask.to(device)
ntokens = torch.sum(ntokens).to(device)
out = model.forward(img_input, decode_in, encode_mask, decode_mask)
loss = loss_compute(out, decode_out, ntokens)
total_loss += loss
total_tokens += ntokens
tokens += ntokens
if i % 50 == 1:
elapsed = time.time() - start
print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
(i, loss / ntokens, tokens / elapsed))
start = time.time()
tokens = 0
return total_loss / total_tokens
for epoch in range(nrof_epochs):
print(f"\nepoch {epoch}")
print("train...")
ocr_model.train()
loss_compute = SimpleLossCompute(ocr_model.generator, criterion, model_opt)
train_mean_loss = run_epoch(train_loader, ocr_model, loss_compute, device)
if epoch % 10 == 0:
print("valid...")
ocr_model.eval()
valid_loss_compute = SimpleLossCompute(ocr_model.generator, criterion, None)
valid_mean_loss = run_epoch(valid_loader, ocr_model, valid_loss_compute, device)
print(f"valid loss: {valid_mean_loss}")
epoch 0
train...
Epoch Step: 1 Loss: 4.756953 Tokens per Sec: 74.231010
Epoch Step: 51 Loss: 3.345229 Tokens per Sec: 227.936249
Epoch Step: 101 Loss: 3.164185 Tokens per Sec: 217.443558
Epoch Step: 151 Loss: 2.884049 Tokens per Sec: 198.306046
Epoch Step: 201 Loss: 2.918671 Tokens per Sec: 204.400925
Epoch Step: 251 Loss: 3.152167 Tokens per Sec: 205.873077
valid...
Epoch Step: 1 Loss: 2.701383 Tokens per Sec: 275.244385
Epoch Step: 51 Loss: 2.951396 Tokens per Sec: 240.986679
Epoch Step: 101 Loss: 2.714810 Tokens per Sec: 261.232330
valid loss: 2.839085102081299
epoch 1
train...
Epoch Step: 1 Loss: 3.549314 Tokens per Sec: 193.934494
Epoch Step: 51 Loss: 2.953091 Tokens per Sec: 198.670242
Epoch Step: 101 Loss: 2.828863 Tokens per Sec: 214.964783
Epoch Step: 151 Loss: 2.756577 Tokens per Sec: 208.429001
5. Greedy decoding
We use the simplest greedy decoding to predict the OCR result directly. Because the model produces only one output at a time, we take the character with the highest probability in the output distribution as the prediction for the current step and then move on to predict the next character; this is greedy decoding, implemented by the greedy_decode() function in the code. In the experiment, every image is fed to the model individually and decoded greedily to tally the accuracy, and the prediction accuracies on the training set and the validation set are reported at the end.
ocr_model.eval()
print("\n------------------------------------------------")
print("greedy decode trainset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(train_loader):
img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
img_input = img_input.to(device)
encode_mask = encode_mask.to(device)
bs = img_input.shape[0]
for i in range(bs):
cur_img_input = img_input[i].unsqueeze(0)
cur_encode_mask = encode_mask[i].unsqueeze(0)
cur_decode_out = decode_out[i]
pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
pred_result = pred_result.cpu()
is_correct = judge_is_correct(pred_result, cur_decode_out)
total_correct_num += is_correct
total_img_num += 1
if not is_correct:
print("----")
print(cur_decode_out)
print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of trainset: {total_correct_rate}%")
print("\n------------------------------------------------")
print("greedy decode validset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(valid_loader):
img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
img_input = img_input.to(device)
encode_mask = encode_mask.to(device)
bs = img_input.shape[0]
for i in range(bs):
cur_img_input = img_input[i].unsqueeze(0)
cur_encode_mask = encode_mask[i].unsqueeze(0)
cur_decode_out = decode_out[i]
pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
pred_result = pred_result.cpu()
is_correct = judge_is_correct(pred_result, cur_decode_out)
total_correct_num += is_correct
total_img_num += 1
if not is_correct:
print("----")
print(cur_decode_out)
print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of validset: {total_correct_rate}%")
The greedy_decode() function is implemented as follows:
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
memory = model.encode(src, src_mask)
ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).long()
for i in range(max_len-1):
out = model.decode(memory, src_mask,
Variable(ys),
Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
prob = model.generator(out[:, -1])
_, next_word = torch.max(prob, dim = 1)
next_word = next_word.data[0]
next_word = torch.ones(1, 1).type_as(src.data).fill_(next_word).long()
ys = torch.cat([ys, next_word], dim=1)
next_word = int(next_word)
if next_word == end_symbol:
break
ys = ys[0, 1:]
return ys
def judge_is_correct(pred, label):
pred_len = pred.shape[0]
label = label[:pred_len]
is_correct = 1 if label.equal(pred) else 0
return is_correct
Epoch Step: 1 Loss: 5.315293 Tokens per Sec: 2073.354492
valid...
Epoch Step: 1 Loss: 3.870697 Tokens per Sec: 2173.835449
valid loss: 3.8293662071228027
epoch 1
train...
Epoch Step: 1 Loss: 3.892932 Tokens per Sec: 2160.098633
epoch 2
train...
Epoch Step: 1 Loss: 3.594534 Tokens per Sec: 2163.552490
tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 22, 47,
2, 0, 0, 0])
tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 33, 24,
34, 55])
tensor([15, 16, 10, 15, 39, 18, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0])
tensor([18, 31, 39, 2])
total correct rate of validset: 95.78313253012048%
6. Summary
This post first introduced the word-recognition dataset built from ICDAR 2015, briefly analyzed its characteristics, and constructed the character mapping table used for recognition. It then focused on the motivation and approach for applying the transformer to the OCR task, walked through the details alongside the code, and finally went over the training-related logic and code.