IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【深度学习】使用transformer实现OCR字符识别 -> 正文阅读

[人工智能]【深度学习】使用transformer实现OCR字符识别

本文是跟着datawhale组队学习学的,原文在这:动手学CV-Pytorch 6.2_使用transformer实现OCR字符识别

1.数据集相关操作

#! pip install opencv-python
import os
import cv2
# 数据集根目录,请将数据下载到此位置
base_data_dir = './ICDAR_2015'
# 训练数据集和验证数据集所在路径
train_img_dir = os.path.join(base_data_dir, 'train')
valid_img_dir = os.path.join(base_data_dir, 'valid')
# 训练集和验证集标签文件路径
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
# 中间文件存储路径,存储标签字符与其id的映射关系
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')

1.1标签最长字符个数统计

def statistics_max_len_label(lbl_path):
    """
    统计标签文件中最长的label所包含的字符数
    lbl_path:txt标签文件路径
    """
    max_len = -1
    with open(lbl_path, 'r',encoding = 'utf-8') as reader:
        for line in reader:
            items = line.rstrip().split(',')
            #img_name = item[0] #提取图像名称
            lbl_str = items[1].strip()[1:-1]#提取标签,去除标签中的引号
            lbl_len = len(lbl_str)
            max_len = max_len if max_len>lbl_len else lbl_len
    return max_len
train_max_label_len = statistics_max_len_label(train_lbl_path) # 训练集最长label
valid_max_label_len = statistics_max_len_label(valid_lbl_path) # 验证集最长label
max_label_len = max(train_max_label_len, valid_max_label_len) # 全数据集最长label
print(f"数据集中包含字符最多的label长度为{max_label_len}")
数据集中包含字符最多的label长度为21
def statistics_label_cnt(lbl_path, lbl_cnt_map):
    """
    统计标签文件中label都包含了哪些字符以及各自出现的次数
    lbl_path:标签所处路径
    lbl_cnt_map:记录标签中字符出现次数的字典
    """
    
    with open(lbl_path, 'r',encoding = 'utf-8') as reader:
        for line in reader:
            items = line.rstrip().split(',')
            #img_name = item[0] #提取图像名称
            lbl_str = items[1].strip()[1:-1]#提取标签,去除标签中的引号
            for lbl in lbl_str:
                if lbl not in lbl_cnt_map.keys():
                    lbl_cnt_map[lbl] = 1
                else:
                    lbl_cnt_map[lbl] +=1

                    
lbl_cnt_map = dict() # 用于存储字符出现次数的字典
statistics_label_cnt(train_lbl_path, lbl_cnt_map) # 训练集中字符出现次数统计
print("训练集中label中出现的字符:")
print(lbl_cnt_map)
statistics_label_cnt(valid_lbl_path, lbl_cnt_map) # 训练集和验证集中字符出现次数统计
print("训练集+验证集label中出现的字符:")
print(lbl_cnt_map)
训练集中label中出现的字符:
{'[': 2, '0': 182, '6': 38, ']': 2, '2': 119, '-': 68, '3': 50, 'C': 593, 'a': 843, 'r': 655, 'p': 197, 'k': 96, 'E': 1421, 'X': 110, 'I': 861, 'T': 896, 'R': 836, 'f': 133, 'u': 293, 's': 557, 'i': 651, 'o': 659, 'n': 605, 'l': 408, 'e': 1055, 'v': 123, 'A': 1189, 'U': 319, 'O': 965, 'N': 785, 'c': 318, 't': 563, 'm': 202, 'W': 179, 'H': 391, 'Y': 229, 'P': 389, 'F': 259, 'G': 345, '?': 5, 'S': 1161, 'b': 88, 'h': 299, ' ': 50, 'g': 171, 'L': 745, 'M': 367, 'D': 383, 'd': 257, '$': 46, '5': 77, '4': 44, '.': 95, 'w': 97, 'B': 331, '1': 184, '7': 43, '8': 44, 'V': 158, 'y': 161, 'K': 163, '!': 51, '9': 66, 'z': 12, ';': 3, '#': 16, 'j': 15, "'": 51, 'J': 72, ':': 19, 'x': 27, '%': 28, '/': 24, 'q': 3, 'Q': 19, '(': 6, ')': 5, '\\': 8, '"': 8, '′': 3, 'Z': 29, '&': 9, 'é': 1, '@': 4, '=': 1, '+': 1}
训练集+验证集label中出现的字符:
{'[': 2, '0': 232, '6': 44, ']': 2, '2': 139, '-': 87, '3': 69, 'C': 893, 'a': 1200, 'r': 935, 'p': 317, 'k': 137, 'E': 2213, 'X': 181, 'I': 1241, 'T': 1315, 'R': 1262, 'f': 203, 'u': 415, 's': 793, 'i': 924, 'o': 954, 'n': 880, 'l': 555, 'e': 1534, 'v': 169, 'A': 1827, 'U': 467, 'O': 1440, 'N': 1158, 'c': 442, 't': 829, 'm': 278, 'W': 288, 'H': 593, 'Y': 341, 'P': 582, 'F': 402, 'G': 521, '?': 7, 'S': 1748, 'b': 129, 'h': 417, ' ': 82, 'g': 260, 'L': 1120, 'M': 536, 'D': 548, 'd': 367, '$': 57, '5': 100, '4': 53, '.': 132, 'w': 136, 'B': 468, '1': 228, '7': 60, '8': 51, 'V': 224, 'y': 231, 'K': 253, '!': 65, '9': 76, 'z': 14, ';': 3, '#': 24, 'j': 19, "'": 70, 'J': 100, ':': 24, 'x': 38, '%': 42, '/': 29, 'q': 3, 'Q': 28, '(': 7, ')': 5, '\\': 8, '"': 8, '′': 3, 'Z': 36, '&': 15, 'é': 2, '@': 9, '=': 1, '+': 2, 'é': 1}

上方代码中,lbl_cnt_map 为字符出现次数的统计字典,后面还会用于建立字符及其id映射关系。从数据集统计结果来看,测试集含有训练集没有出现过的字符,例如测试集中包含1个’é’未曾在训练集出现。这种情况数量不多,应该问题不大,所以此处未对数据集进行额外处理(但是有意识的进行这种训练集和测试集是否存在diff的检查是必要的)。

1.2char和id的映射字典构建

# 构造label中 字符--id 之间的映射
print("构造label中 字符--id之间的映射:")

lbl2id_map = dict()
lbl2id_map['?'] = 0 # padding标识符
lbl2id_map['■'] = 1 # 句子起始符
lbl2id_map['□'] = 2 # 句子结束符
#生成其余字符的id映射关系
cur_id = 3
for lbl in lbl_cnt_map.keys():
    lbl2id_map[lbl] = cur_id
    cur_id += 1
     
#保存 字符--id 之间的映射 到txt文件

with open(lbl2id_map_path, 'w', encoding='utf-8') as writer:
    for lbl in lbl2id_map.keys():
        cur_id = lbl2id_map[lbl]
        print (lbl, cur_id)
        line = lbl + '\t' + str(cur_id) + '\n'
        writer.write(line)
构造label中 字符--id之间的映射:
? 0
■ 1
□ 2
[ 3
0 4
6 5
] 6
2 7
- 8
3 9
C 10
a 11
r 12
p 13
k 14
E 15
X 16
I 17
T 18
R 19
f 20
u 21
s 22
i 23
o 24
n 25
l 26
e 27
v 28
A 29
U 30
O 31
N 32
c 33
t 34
m 35
W 36
H 37
Y 38
P 39
F 40
G 41
? 42
S 43
b 44
h 45
  46
g 47
L 48
M 49
D 50
d 51
$ 52
5 53
4 54
. 55
w 56
B 57
1 58
7 59
8 60
V 61
y 62
K 63
! 64
9 65
z 66
; 67
# 68
j 69
' 70
J 71
: 72
x 73
% 74
/ 75
q 76
Q 77
( 78
) 79
\ 80
" 81
′ 82
Z 83
& 84
é 85
@ 86
= 87
+ 88
é 89
def load_lbl2id_map(lbl2id_map_path):
    """
    读取 字符-id 映射关系记录的txt文件,并返回 lbl->id 和 id->lbl 映射字典
    lbl2id_map_path : 字符-id 映射关系记录的txt文件路径
    """
    lbl2id_map = dict()
    id2lbl_map = dict()
    with open(lbl2id_map_path, 'r',encoding = 'utf-8') as reader:
        for line in reader:
            items = line.rstrip().split('\t')
            label = items[0]
            cur_id = int(items[1])
            lbl2id_map[label] = cur_id
            id2lbl_map[cur_id] = label
    return lbl2id_map, id2lbl_map

1.3数据集图像尺寸分析

# 分析数据集图片尺寸
print("分析数据集图片尺寸:")
# 初始化参数
min_h = 1e10
min_w = 1e10
max_h = -1
max_w = -1
min_ratio = 1e10
max_ratio = 0
# 遍历数据集计算尺寸信息
for img_name in os.listdir(train_img_dir):
    img_path = os.path.join(train_img_dir,img_name)
    img = cv2.imread(img_path)
    h, w = img.shape[:2]
    ratio = w / h #高宽比
    min_h = min_h if min_h <= h else h # 最小图片高度
    max_h = max_h if max_h >= h else h # 最大图片高度
    min_w = min_w if min_w <= w else w # 最小图片宽度
    max_w = max_w if max_w >= w else w # 最大图片宽度
    min_ratio = min_ratio if min_ratio <= ratio else ratio # 最小图片高宽比
    max_ratio = max_ratio if max_ratio >= ratio else ratio # 最大图片高宽比
    # 输出信息
print('min_h:', min_h)
print('max_h:', max_h)
print('min_w:', min_w)
print('max_w:', max_w)
print('min_ratio:', min_ratio)
print('max_ratio:', max_ratio)
分析数据集图片尺寸:
min_h: 9
max_h: 295
min_w: 16
max_w: 628
min_ratio: 0.6666666666666666
max_ratio: 8.619047619047619

2.将transformer引入OCR

在这里插入图片描述

2.1准备工作

import os
import time
import copy
import numpy as np
from PIL import Image
# torch相关包
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms
# 导入工具类包
# from analysis_recognition_dataset import load_lbl2id_map, statistics_max_len_label
from transformer import *
from train_utils import *
base_data_dir = './ICDAR_2015/' # 数据集根目录,请将数据下载到此位置
device = torch.device('cuda') # 'cpu'或者'cuda'
nrof_epochs = 1500 # 迭代次数,1500,根据需求进行修正
batch_size = 16 # 批量大小,32,根据需求进行修正
model_save_path = './log/ex1_ocr_model.pth'
# 读取label-id映射关系记录文件
lbl2id_map_path = os.path.join(base_data_dir, 'lbl2id_map.txt')
lbl2id_map, id2lbl_map = load_lbl2id_map(lbl2id_map_path)
# 统计数据集中出现的所有的label中包含字符最多的有多少字符,数据集构造gt(ground truth)信息需要用到
train_lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
valid_lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
train_max_label_len = statistics_max_len_label(train_lbl_path)
valid_max_label_len = statistics_max_len_label(valid_lbl_path)
# 数据集中字符数最多的一个case作为制作的gt的sequence_len
sequence_len = max(train_max_label_len, valid_max_label_len)

2.2数据集创建

class Recognition_Dataset(object):

    def __init__(self, dataset_root_dir, lbl2id_map, sequence_len, max_ratio, phase='train', pad=0):

        if phase == 'train':
            self.img_dir = os.path.join(base_data_dir, 'train')
            self.lbl_path = os.path.join(base_data_dir, 'train_gt.txt')
        else:
            self.img_dir = os.path.join(base_data_dir, 'valid')
            self.lbl_path = os.path.join(base_data_dir, 'valid_gt.txt')
        self.lbl2id_map = lbl2id_map
        self.pad = pad   # padding标识符的id,默认0
        self.sequence_len = sequence_len    # 序列长度
        self.max_ratio = max_ratio * 3      # 将宽拉长3倍

        self.imgs_list = []
        self.lbls_list = []
        with open(self.lbl_path, 'r',encoding = 'utf-8') as reader:
            for line in reader:
                items = line.rstrip().split(',')
                img_name = items[0]
                lbl_str = items[1].strip()[1:-1]
                self.imgs_list.append(img_name)
                self.lbls_list.append(lbl_str)

        # 定义随机颜色变换
        self.color_trans = transforms.ColorJitter(0.1, 0.1, 0.1)
        # 定义 Normalize
        self.trans_Normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225]),
        ]) 

    def __getitem__(self, index):
        """ 
        获取对应index的图像和ground truth label,并视情况进行数据增强
        """
        img_name = self.imgs_list[index]
        img_path = os.path.join(self.img_dir, img_name)
        lbl_str = self.lbls_list[index]

        # ----------------
        # 图片预处理
        # ----------------
        # load image
        img = Image.open(img_path).convert('RGB')

        # 对图片进行大致等比例的缩放
        # 将高缩放到32,宽大致等比例缩放,但要被32整除
        w, h = img.size
        ratio = round((w / h) * 3)   # 将宽拉长3倍,然后四舍五入
        if ratio == 0:
            ratio = 1
        if ratio > self.max_ratio:
            ratio = self.max_ratio
        h_new = 32
        w_new = h_new * ratio
        img_resize = img.resize((w_new, h_new), Image.BILINEAR)

        # 对图片右半边进行padding,使得宽/高比例固定=self.max_ratio
        img_padd = Image.new('RGB', (32*self.max_ratio, 32), (0,0,0))
        img_padd.paste(img_resize, (0, 0))

        # 随机颜色变换
        img_input = self.color_trans(img_padd)
        # Normalize
        img_input = self.trans_Normalize(img_input)

        # ----------------
        # label处理
        # ----------------

        # 构造encoder的mask
        encode_mask = [1] * ratio + [0] * (self.max_ratio - ratio)
        encode_mask = torch.tensor(encode_mask)
        encode_mask = (encode_mask != 0).unsqueeze(0)

        # 构造ground truth label
        gt = []
        gt.append(1)    # 先添加句子起始符
        for lbl in lbl_str:
            gt.append(self.lbl2id_map[lbl])
        gt.append(2)
        for i in range(len(lbl_str), self.sequence_len):   
        # 除去起始符终止符,lbl长度为sequence_len,剩下的padding
            gt.append(0)
        # 截断为预设的最大序列长度
        gt = gt[:self.sequence_len]

        # decoder的输入
        decode_in = gt[:-1]
        decode_in = torch.tensor(decode_in)
        # decoder的输出
        decode_out = gt[1:]
        decode_out = torch.tensor(decode_out)
        # decoder的mask 
        decode_mask = self.make_std_mask(decode_in, self.pad)
        # 有效tokens数
        ntokens = (decode_out != self.pad).data.sum()

        return img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens 
    
    @staticmethod
    def make_std_mask(tgt, pad):
        """
        Create a mask to hide padding and future words.
        padd 和 future words 均在mask中用0表示
        """
        tgt_mask = (tgt != pad)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        tgt_mask = tgt_mask.squeeze(0)   # subsequent返回值的shape是(1, N, N)
        return tgt_mask

    def __len__(self):
        return len(self.imgs_list)

以上是构建Dataset的所有细节,进而我们可以构建出DataLoader供训练使用

# 构造 dataloader
max_ratio = 8    # 图片预处理时 宽/高的最大值,不超过就保比例resize,超过会强行压缩
train_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'train', pad=0)
valid_dataset = Recognition_Dataset(base_data_dir, lbl2id_map, sequence_len, max_ratio, 'valid', pad=0)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=0)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=0)  

3.模型构建

代码通过 make_ocr_modelOCR_EncoderDecoder 类完成模型结构搭建。

make_ocr_model 这个函数看起,该函数首先调用了pytorch中预训练的Resnet-18作为backbone以提取图像特征,此处也可以根据自己需要调整为其他的网络,但需要重点关注的是网络的下采样倍数,以及最后一层特征图的channel_num,相关模块的参数需要同步调整。之后调用了OCR_EncoderDecoder 类完成transformer的搭建。最后对模型参数进行初始化。

OCR_EncoderDecoder 类中,该类相当于是一个transformer各基础组件的拼装线,包括 encoder和 decoder 等,其初始参数是已存在的基本组件,其基本组件代码都在transformer.py文件中,本文不过多赘述。

来回顾一下,图片经过backbone后,如何构造为Transformer的输入:
图片经过backbone后将输出一个维度为 [batch_size, 512, 1, 24] 的特征图,在不关注batch_size的前提下,每一张图像都会得到如下所示具有512个通道的1×24的特征图,如图中红色框标注所示,将不同通道相同位置的特征值拼接组成一个新的向量,并作为一个时间步的输入,此时变构造出了维度为[batch_size, 24, 512] 的输入,满足Transformer的输入要求。
在这里插入图片描述下面看完整的构造模型部分的代码:

# Model Architecture
class OCR_EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. 
    Base for this and many other models.
    """
    def __init__(self, encoder, decoder, src_embed, src_position, tgt_embed, generator):
        super(OCR_EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed    # input embedding module(input embedding + positional encode)
        self.src_position = src_position
        self.tgt_embed = tgt_embed    # ouput embedding module
        self.generator = generator    # output generation module
        
    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        memory = self.encode(src, src_mask)
        res = self.decode(memory, src_mask, tgt, tgt_mask)
        return res
    
    def encode(self, src, src_mask):
        # feature extract
        src_embedds = self.src_embed(src)
        # 将src_embedds由shape(bs, model_dim, 1, max_ratio) 处理为transformer期望的输入shape(bs, 时间步, model_dim)
        src_embedds = src_embedds.squeeze(-2)
        src_embedds = src_embedds.permute(0, 2, 1)

        # position encode
        src_embedds = self.src_position(src_embedds)

        return self.encoder(src_embedds, src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        target_embedds = self.tgt_embed(tgt)
        return self.decoder(target_embedds, memory, src_mask, tgt_mask)


def make_ocr_model(tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    构建模型
    params:
        tgt_vocab: 输出的词典大小(82)
        N: 编码器和解码器堆叠基础模块的个数
        d_model: 模型中embedding的size,默认512
        d_ff: FeedForward Layer层中embedding的size,默认2048
        h: MultiHeadAttention中多头的个数,必须被d_model整除
        dropout:
    """
    c = copy.deepcopy

    backbone = models.resnet18(pretrained=True)
    backbone = nn.Sequential(*list(backbone.children())[:-2])    # 去掉最后两个层 (global average pooling and fc layer)

    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    model = OCR_EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        backbone,
        c(position),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))
    
    # Initialize parameters with Glorot / fan_avg.
    for child in model.children():
        if child is backbone:
            # 将backbone的权重设为不计算梯度
            for param in child.parameters():
                param.requires_grad = False
            # 预训练好的backbone不进行随机初始化,其余模块进行随机初始化
            continue
        for p in child.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

构建transformer模型:

# build model
# use transformer as ocr recognize model
tgt_vocab = len(lbl2id_map.keys())
d_model = 512
ocr_model = make_ocr_model(tgt_vocab, N=5, d_model=d_model, d_ff=2048, h=8, dropout=0.1)
ocr_model.to(device)

4.模型训练

模型训练之前,还需要定义模型评判准则、迭代优化器等。本实验在训练时,使用了标签平滑(label smoothing)、网络训练热身(warmup)等策略,以上策略的调用函数均在train_utils.py 文件中,此处不涉及以上两种方法的原理及代码实现。

label smoothing可以将原始的硬标签转化为软标签,从而增加模型的容错率,提升模型泛化能力。代码中 LabelSmoothing() 函数实现了label smoothing,同时内部使用了相对熵函数计算了预测值与真实值之间的损失。

warmup策略能够有效控制模型训练过程中的优化器学习率,自动化的实现模型学习率由小增大再逐渐下降的控制,帮助模型在训练时更加稳定,实现损失的快速收敛。代码中 NoamOpt() 函数实现了warmup控制,采用的Adam优化器,实现学习率随迭代次数的自动调整。

# train prepare
criterion = LabelSmoothing(size=tgt_vocab, padding_idx=0, smoothing=0.0)
#optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ocr_model.parameters()), 
#                            lr=0,
#                            betas=(0.9, 0.98),
#                            eps=1e-9)
optimizer = torch.optim.Adam(ocr_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(d_model, 1, 400, optimizer)

SimpleLossCompute() 类实现了transformer输出结果的loss计算。在使用该类直接计算时,类需要接收(x, y, norm) 三个参数, x 为decoder输出的结果, y 为标签数据, norm 为loss的归一化系数,用batch中所有有效token数即可。由此可见,此处才正完成transformer所有网络的构建,实现数据计算
流的流通。

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt
        
    def __call__(self, x, y, norm):
        """
        norm: loss的归一化系数,用batch中所有有效token数即可
        """
        x = self.generator(x)
        x_ = x.contiguous().view(-1, x.size(-1))
        y_ = y.contiguous().view(-1)
        loss = self.criterion(x_, y_)
        loss /= norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        #return loss.data[0] * norm  # TODO
        return loss.item() * norm

模型训练过程的代码如下所示,每训练10个epoch便进行一次验证,单个epoch的计算过程封装在run_epoch() 函数中。

def run_epoch(data_loader, model, loss_compute, device=None):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_loader):
        #if device == "cuda":
        #    batch.to_device(device) 
        img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
        img_input = img_input.to(device)                        
        encode_mask = encode_mask.to(device)                                
        decode_in = decode_in.to(device)                                
        decode_out = decode_out.to(device)                    
        decode_mask = decode_mask.to(device)
        ntokens = torch.sum(ntokens).to(device)

        out = model.forward(img_input, decode_in, encode_mask, decode_mask)

        loss = loss_compute(out, decode_out, ntokens)
        total_loss += loss
        total_tokens += ntokens
        tokens += ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens
for epoch in range(nrof_epochs):
    print(f"\nepoch {epoch}")

    print("train...")
    ocr_model.train()
    loss_compute = SimpleLossCompute(ocr_model.generator, criterion, model_opt)
    train_mean_loss = run_epoch(train_loader, ocr_model, loss_compute, device)

    if epoch % 10 == 0:
        print("valid...")
        ocr_model.eval()
        valid_loss_compute = SimpleLossCompute(ocr_model.generator, criterion, None)
        valid_mean_loss = run_epoch(valid_loader, ocr_model, valid_loss_compute, device)
        print(f"valid loss: {valid_mean_loss}")
epoch 0
train...
Epoch Step: 1 Loss: 4.756953 Tokens per Sec: 74.231010
Epoch Step: 51 Loss: 3.345229 Tokens per Sec: 227.936249
Epoch Step: 101 Loss: 3.164185 Tokens per Sec: 217.443558
Epoch Step: 151 Loss: 2.884049 Tokens per Sec: 198.306046
Epoch Step: 201 Loss: 2.918671 Tokens per Sec: 204.400925
Epoch Step: 251 Loss: 3.152167 Tokens per Sec: 205.873077
valid...
Epoch Step: 1 Loss: 2.701383 Tokens per Sec: 275.244385
Epoch Step: 51 Loss: 2.951396 Tokens per Sec: 240.986679
Epoch Step: 101 Loss: 2.714810 Tokens per Sec: 261.232330
valid loss: 2.839085102081299

epoch 1
train...
Epoch Step: 1 Loss: 3.549314 Tokens per Sec: 193.934494
Epoch Step: 51 Loss: 2.953091 Tokens per Sec: 198.670242
Epoch Step: 101 Loss: 2.828863 Tokens per Sec: 214.964783
Epoch Step: 151 Loss: 2.756577 Tokens per Sec: 208.429001

5.贪心解码

我们使用最简单的贪心解码直接进行OCR结果预测。因为模型每一次只会产生一个输出,我们选择输出的概率分布中的最高概率对应的字符为本次预测的结果,然后预测下一个字符,这就是所谓的贪心解码,见代码中 greedy_decode() 函数。
实验中分别将每一张图像作为模型的输入,逐张进行贪心解码统计正确率,并最终给出了训练集和验证集各自的预测准确率。

# 训练结束,使用贪心的解码方式推理训练集和验证集,统计正确率
ocr_model.eval()
print("\n------------------------------------------------")
print("greedy decode trainset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(train_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)                        
    encode_mask = encode_mask.to(device)                                

    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]

        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()

        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # 预测错误的case进行打印
            print("----")
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of trainset: {total_correct_rate}%")

print("\n------------------------------------------------")
print("greedy decode validset")
total_img_num = 0
total_correct_num = 0
for batch_idx, batch in enumerate(valid_loader):
    img_input, encode_mask, decode_in, decode_out, decode_mask, ntokens = batch
    img_input = img_input.to(device)                        
    encode_mask = encode_mask.to(device)                                

    bs = img_input.shape[0]
    for i in range(bs):
        cur_img_input = img_input[i].unsqueeze(0)
        cur_encode_mask = encode_mask[i].unsqueeze(0)
        cur_decode_out = decode_out[i]

        pred_result = greedy_decode(ocr_model, cur_img_input, cur_encode_mask, max_len=sequence_len, start_symbol=1, end_symbol=2)
        pred_result = pred_result.cpu()

        is_correct = judge_is_correct(pred_result, cur_decode_out)
        total_correct_num += is_correct
        total_img_num += 1
        if not is_correct:
            # 预测错误的case进行打印
            print("----")
            print(cur_decode_out)
            print(pred_result)
total_correct_rate = total_correct_num / total_img_num * 100
print(f"total correct rate of validset: {total_correct_rate}%")

greedy_decode() 函数实现如下:

# greedy decode
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
    memory = model.encode(src, src_mask)
    # ys代表目前已生成的序列,最初为仅包含一个起始符的序列,不断将预测结果追加到序列最后
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data).long()
    for i in range(max_len-1):
        out = model.decode(memory, src_mask, 
                           Variable(ys), 
                           Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        next_word = torch.ones(1, 1).type_as(src.data).fill_(next_word).long()
        ys = torch.cat([ys, next_word], dim=1)

        next_word = int(next_word)
        if next_word == end_symbol:
            break
        #ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    ys = ys[0, 1:]
    return ys


def judge_is_correct(pred, label):
    # 判断模型预测结果和label是否一致
    pred_len = pred.shape[0]
    label = label[:pred_len]
    is_correct = 1 if label.equal(pred) else 0
    return is_correct
Epoch Step: 1 Loss: 5.315293 Tokens per Sec: 2073.354492
valid...
Epoch Step: 1 Loss: 3.870697 Tokens per Sec: 2173.835449
valid loss: 3.8293662071228027

epoch 1
train...
Epoch Step: 1 Loss: 3.892932 Tokens per Sec: 2160.098633

epoch 2
train...
Epoch Step: 1 Loss: 3.594534 Tokens per Sec: 2163.552490


tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 22, 47,
         2,  0,  0,  0])
tensor([56, 56, 56, 55, 62, 56, 47, 12, 24, 21, 13, 55, 33, 24, 35, 55, 33, 24,
        34, 55])

tensor([15, 16, 10, 15, 39, 18,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0])
tensor([18, 31, 39,  2])

total correct rate of validset: 95.78313253012048%

6.总结

本文首先介绍了所使用的ICDAR2015中的一个单词识别任务数据集,然后对数据的特点进行了简单分析,并构建了识别用的字符映射关系表。之后,重点介绍了将transformer引入来解决OCR任务的动机与思路,并结合代码详细介绍了细节,最后大致过了一些训练相关的逻辑和代码。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-22 10:56:24  更:2021-10-22 10:57:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 8:29:06-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码