1、制作lmdb数据集
由于数据集制作不涉及核心,我就直接引用了python2的代码,当然有时间的同学们也可以自己改成python3的。 先上环境,我使用anaconda新建了一个py27的环境(注意:opencv最后一次对python2的维护止于python2.7!!!) 这是我的pip list:
Package Version
------------- -------------------
certifi 2020.6.20
lmdb 1.3.0
numpy 1.16.6
opencv-python 4.2.0.32
pip 19.3.1
setuptools 44.0.0.post20200106
wheel 0.37.1
接着上数据集: 关于数据集的标注,大家可使用labelimg,标签写成对应的文字内容即可。制作好数据集后,根据标注的坐标依次将子图抠出来,再为每张子图生成对应的txt标签文件就ok了。 拿到数据集后,确保txt和图片在同一个文件夹内,使用python2.7运行以下代码:
# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
#from genLineText import GenTextImage
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with a non-zero area.

    ARGS:
        imageBin : raw encoded image bytes (the content of an image file), or None
    """
    # Reject None and empty input before decoding; recent OpenCV releases are
    # reported to raise on an empty buffer instead of returning None.
    if not imageBin:
        return False
    # np.frombuffer is the supported replacement for the deprecated
    # np.fromstring when reading binary data.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0
def writeCache(env, cache):
    """Write every key/value pair of ``cache`` to ``env`` in a single write transaction.

    ARGS:
        env   : opened LMDB environment (anything whose ``begin(write=True)``
                returns a transaction context manager with ``put``)
        cache : dict mapping keys to values; each may be bytes or text
    """
    with env.begin(write=True) as txn:
        # dict.items() works on Python 2 and 3 (iteritems() is Python 2 only).
        for k, v in cache.items():
            # Text is UTF-8 encoded so that only bytes are handed to the database.
            key = k if isinstance(k, bytes) else k.encode('utf-8')
            value = v if isinstance(v, bytes) else v.encode('utf-8')
            txn.put(key, value)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Samples are stored under the keys ``image-%09d`` / ``label-%09d``
    (and ``lexicon-%09d`` when a lexicon is given), numbered from 1, plus a
    ``num-samples`` entry holding the number of samples actually written.
    Missing files and (with ``checkValid``) undecodable images are skipped.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    def _asBytes(value):
        # Keys and labels are stored as bytes; text is UTF-8 encoded.
        return value if isinstance(value, bytes) else value.encode('utf-8')

    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    print('...................')
    # map_size=1099511627776 caps the database size at 1 TB
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i in range(nSamples):
            imagePath = imagePathList[i]
            label = labelList[i]
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            # Binary mode: image files must not go through text decoding or
            # newline translation.
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            # The database holds two kinds of records per sample, the image
            # data and the label, each under its own key.
            imageKey = _asBytes('image-%09d' % cnt)
            labelKey = _asBytes('label-%09d' % cnt)
            cache[imageKey] = imageBin
            cache[labelKey] = _asBytes(label)
            if lexiconList:
                lexiconKey = _asBytes('lexicon-%09d' % cnt)
                cache[lexiconKey] = _asBytes(' '.join(lexiconList[i]))
            # Flush to disk every 1000 samples to bound memory usage.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt - 1
        cache[b'num-samples'] = _asBytes(str(nSamples))
        writeCache(env, cache)
    finally:
        # Release the environment even if a sample fails half-way.
        env.close()
    print('Created dataset with %d samples' % nSamples)
def read_text(path):
    """Return the content of the text file at ``path`` with surrounding whitespace removed."""
    with open(path) as handle:
        return handle.read().strip()
import glob
if __name__ == '__main__':
    # LMDB output directory
    outputPath = './lmdb3_sample'
    # Training images; every label is a .txt file sharing the image's base
    # name, e.g. 123.png is labelled by 123.txt
    path = './data_sample/*.png'
    imagePathList = glob.glob(path)
    print('------------ %d ------------' % len(imagePathList))
    imgLabelLists = []
    for p in imagePathList:
        # Swap only the extension; str.replace would also rewrite a '.png'
        # that appears elsewhere in the path.
        labelPath = os.path.splitext(p)[0] + '.txt'
        try:
            imgLabelLists.append((p, read_text(labelPath)))
        except (IOError, OSError, UnicodeError) as err:
            # Best effort: an image without a readable label is skipped, but
            # reported instead of being dropped silently.
            print('skip %s: cannot read label (%s)' % (p, err))
            continue
    # sort by label length
    imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
    imgPaths = [p[0] for p in imgLabelList]
    txtLists = [p[1] for p in imgLabelList]
    createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)
运行完后,生成如下文件:
接下来就可以训练了
2、训练
可以从 https://github.com/meijieru/crnn.pytorch 顺手拿一份代码,拿完后记得修改,不然在我们使用的高版本cuda、torch等环境下会报错。主要改动是去掉了Variable并修复了部分其他bug,其中有一部分bug是参考这位小哥的博客改的:
https://www.cnblogs.com/yanghailin/p/14519525.html
问题不大
1、train.py
from __future__ import print_function
from __future__ import division
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
import os
import utils
import dataset
import models.crnn as crnn
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Without this, argparse keeps the raw string and ``--cuda False`` would be
    truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--trainRoot', default="./data/lmdb/", help='path to dataset')
parser.add_argument('--valRoot', default="./data/lmdb/", help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=5, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for')
# TODO(meijieru): epoch -> iter
# `--cuda` and `--adadelta` default to True; they accept an optional explicit
# value (`--cuda false`) and also work as bare flags (`--cuda`).
parser.add_argument('--cuda', type=_str2bool, nargs='?', const=True, default=True, help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--pretrained', default='', help="path to pretrained model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
parser.add_argument('--expr_dir', default='expr', help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=1, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=1, help='Interval (in iterations) between validation runs')
parser.add_argument('--saveInterval', type=int, default=100, help='Interval (in iterations) between checkpoints')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, not used by adadelta')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (takes precedence over adadelta)')
parser.add_argument('--adadelta', type=_str2bool, nargs='?', const=True, default=True,
                    help='Whether to use adadelta (default: true; set to false to fall back to rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--manualSeed', type=int, default=1234, help='reproduce experiment')
parser.add_argument('--random_sample', default=False, action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.expr_dir):
    os.makedirs(opt.expr_dir)

# Seed every random number generator in use so that runs are repeatable.
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

train_dataset = dataset.lmdbDataset(root=opt.trainRoot)
assert train_dataset
if opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize, sampler=sampler,
        num_workers=int(opt.workers),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
else:
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize,
        shuffle=True, sampler=None,
        num_workers=int(opt.workers),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))

# Validation images are resized to the size given on the command line
# (width, height) rather than a hard-coded (100, 32), so that a non-default
# --imgW/--imgH is honoured for validation as well as for training.
Val_dataset = dataset.lmdbDataset(
    root=opt.valRoot, transform=dataset.resizeNormalize((opt.imgW, opt.imgH)))
Valdata_loader = torch.utils.data.DataLoader(
    Val_dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))

# One output class per alphabet character plus one extra class.
nclass = len(opt.alphabet) + 1
# Number of input channels passed to the CRNN constructor.
nc = 1

converter = utils.strLabelConverter(opt.alphabet)
criterion = torch.nn.CTCLoss()
# custom weights initialization called on crnn
def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    Classes whose name contains ``Conv`` get weights drawn from N(0.0, 0.02);
    classes whose name contains ``BatchNorm`` get weights drawn from
    N(1.0, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# NOTE: from here on the name `crnn` refers to the model instance, not the module.
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.pretrained != '':
    import collections
    print('loading pretrained model from %s' % opt.pretrained)
    # Load onto the CPU first so a checkpoint written on a GPU machine can be
    # read anywhere; the model is moved to the GPU further below if requested.
    load_model_ = torch.load(opt.pretrained, map_location='cpu')
    state_dict_rename = collections.OrderedDict()
    prefix = 'module.'
    for k, v in load_model_.items():
        # Strip the DataParallel prefix only when it is really there;
        # unconditionally dropping 7 characters would corrupt the keys of a
        # checkpoint that was saved without DataParallel.
        name = k[len(prefix):] if k.startswith(prefix) else k
        state_dict_rename[name] = v
    crnn.load_state_dict(state_dict_rename)
print(crnn)

# Reusable buffers; utils.loadData resizes them to each batch before use.
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=list(range(opt.ngpu)))
    image = image.cuda()
    criterion = criterion.cuda()

# loss averager
loss_avg = utils.averager()

# setup optimizer
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
def val(net, Valdata_loader, criterion, max_iter=100):
    """Evaluate ``net`` on at most ``max_iter`` batches of ``Valdata_loader``.

    Prints the raw/decoded predictions of the last evaluated batch (up to
    ``opt.n_test_disp`` samples) followed by the average loss and the accuracy.
    A prediction counts as correct when it equals the lower-cased ground truth.

    Uses the module-level buffers ``image``, ``text``, ``length`` and the
    module-level ``converter`` and ``opt``.

    Args:
        net: the model to evaluate; it is switched to eval mode.
        Valdata_loader: loader yielding ``(images, texts)`` batches.
        criterion: loss called as ``criterion(preds, text, preds_size, length)``.
        max_iter (int): upper bound on the number of batches evaluated.
    """
    print('Start val')
    net.eval()
    val_iter = iter(Valdata_loader)
    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()
    max_iter = min(max_iter, len(Valdata_loader))
    sim_preds, cpu_texts, preds, preds_size = [], [], None, None
    with torch.no_grad():
        for _ in range(max_iter):
            # Built-in next(): the iterator's .next() method is a Python 2 idiom.
            cpu_images, cpu_texts = next(val_iter)
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            # Evaluate the network that was passed in, not the global model.
            preds = net(image)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)
            # max over the class dimension leaves a 2-D index tensor; no
            # squeeze here, it would drop the batch dimension of a
            # single-sample batch and break the transpose below.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            if isinstance(sim_preds, str):
                # decode returns a bare string for a single-sample batch
                sim_preds = [sim_preds]
            for pred, target in zip(sim_preds, cpu_texts):
                if pred == target.lower():
                    n_correct += 1
            n_total += batch_size
    if n_total:
        raw_preds = converter.decode(preds.data, preds_size.data, raw=True)
        if isinstance(raw_preds, str):
            raw_preds = [raw_preds]
        raw_preds = raw_preds[:opt.n_test_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
    # Divide by the number of samples actually seen: the last batch may be
    # smaller than opt.batchSize, and the loader may be empty.
    accuracy = n_correct / float(n_total) if n_total else 0.0
    print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
def trainBatch(net, criterion, optimizer):
    """Run one optimisation step on the next batch of the global ``train_iter``.

    Uses the module-level buffers ``image``, ``text``, ``length`` and the
    module-level ``converter``.

    Args:
        net: the model being trained.
        criterion: loss called as ``criterion(preds, text, preds_size, length)``.
        optimizer: optimizer stepping the parameters of ``net``.

    Returns:
        The batch loss divided by the batch size, detached from the graph.
    """
    # Built-in next(): the iterator's .next() method is a Python 2 idiom.
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    # Train the network that was passed in, not the global model.
    preds = net(image)
    preds_size = torch.IntTensor([preds.size(0)] * batch_size)
    cost = criterion(preds, text, preds_size, length) / batch_size
    net.zero_grad()
    cost.backward()
    optimizer.step()
    # The caller only logs the value, so hand back a graph-free tensor.
    return cost.detach()
# Main training loop: one pass over train_loader per epoch, with periodic
# logging, validation and checkpointing driven by the command-line intervals.
for epoch in range(opt.nepoch):
    train_iter = iter(train_loader)
    i = 0
    for i in range(1, len(train_loader) + 1):
        # val() switches the model to eval mode, so training mode (and
        # trainable parameters) are re-established before every step.
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)

        if i % opt.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, opt.nepoch, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()

        if i % opt.valInterval == 0:
            val(crnn, Valdata_loader, criterion)

        # do checkpointing
        ckpt_path = '{0}/netCRNN_{1}_{2}.pth'.format(opt.expr_dir, epoch, i)
        if i % opt.saveInterval == 0:
            torch.save(crnn.state_dict(), ckpt_path)
        # additionally checkpoint on the first step of every 100th epoch
        if epoch != 0 and epoch % 100 == 0 and i == 1:
            torch.save(crnn.state_dict(), ckpt_path)
2、dataset.py
#!/usr/bin/python
# encoding: utf-8
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import numpy as np
class lmdbDataset(Dataset):
    """Dataset backed by an LMDB directory.

    Sample ``i`` (0-based) is read from the keys ``image-%09d`` and
    ``label-%09d`` formatted with ``i + 1``; the number of samples comes from
    the ``num-samples`` key.
    """

    def __init__(self, root=None, transform=None, target_transform=None):
        """Open the database read-only and read the sample count.

        Args:
            root: path of the LMDB directory.
            transform: optional callable applied to each decoded image.
            target_transform: optional callable applied to each label string.
        """
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)
        if not self.env:
            print('cannot creat lmdb from %s' % (root))
            sys.exit(0)
        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))
            print("nSamples===================", nSamples)
            self.nSamples = nSamples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(image, label)`` for the 0-based ``index``.

        The image is decoded as a greyscale (``'L'``) PIL image. If decoding
        fails the next sample is returned instead, wrapping around at the end.
        """
        # Valid 0-based indices are 0 .. len-1; `index <= len(self)` let
        # index == len through, which maps to a key that does not exist.
        assert 0 <= index < len(self), 'index range error'
        # Database keys are 1-based. int() also avoids modifying a tensor
        # index in place.
        key_id = int(index) + 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % key_id
            imgbuf = txn.get(img_key.encode())
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % key_id)
                # key_id equals the 0-based index of the following sample;
                # the modulo keeps the fallback inside the dataset.
                return self[key_id % len(self)]
            if self.transform is not None:
                img = self.transform(img)
            label_key = 'label-%09d' % key_id
            label_byte = txn.get(label_key.encode())
            label = label_byte.decode()
            if self.target_transform is not None:
                label = self.target_transform(label)
        return (img, label)
class resizeNormalize(object):
    """Resize a PIL image, convert it to a tensor and rescale it.

    After ``ToTensor`` the values are shifted by -0.5 and divided by 0.5.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # `size` is handed to PIL's Image.resize unchanged
        self.size, self.interpolation = size, interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        tensor.sub_(0.5).div_(0.5)
        return tensor
class randomSequentialSampler(sampler.Sampler):
    """Sampler yielding runs of consecutive indices with random start points.

    Every run has ``batch_size`` consecutive indices; if the dataset size is
    not a multiple of ``batch_size`` a shorter final run fills the remainder,
    so exactly ``len(data_source)`` indices are produced per iteration.
    """

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        # Largest admissible start of a full run; clamped at 0 so a dataset
        # smaller than one batch does not make randint fail.
        max_start = max(0, len(self) - self.batch_size)
        index = []
        for _ in range(n_batch):
            random_start = random.randint(0, max_start)
            # Plain integer ranges replace the deprecated torch.range.
            index.extend(range(random_start, random_start + self.batch_size))
        # deal with tail
        if tail:
            random_start = random.randint(0, max_start)
            index.extend(range(random_start, random_start + tail))
        return iter(index)

    def __len__(self):
        return self.num_samples
class alignCollate(object):
    """Collate function that resizes all images of a batch to one common size.

    With ``keep_ratio`` the common width is derived from the largest
    width/height ratio in the batch, but is never smaller than
    ``imgH * min_ratio``; otherwise the fixed ``imgW`` is used.
    """

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        """Return ``(images, labels)``: one stacked tensor and the labels tuple."""
        images, labels = zip(*batch)
        target_h, target_w = self.imgH, self.imgW
        if self.keep_ratio:
            # image.size is (width, height)
            widest = max(im.size[0] / float(im.size[1]) for im in images)
            target_w = int(np.floor(widest * target_h))
            # enforce target_w >= target_h * min_ratio
            target_w = max(target_h * self.min_ratio, target_w)
        to_tensor = resizeNormalize((target_w, target_h))
        tensors = [to_tensor(im) for im in images]
        stacked = torch.cat([t.unsqueeze(0) for t in tensors], 0)
        return stacked, labels
3、utils.py
#!/usr/bin/python
# encoding: utf-8
import time
import torch
import torch.nn as nn
import collections
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=True):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or iterable of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: when a character is not part of the alphabet.
        """
        if isinstance(text, str):
            codes = [
                self.dict[char.lower() if self._ignore_case else char]
                for char in text
            ]
            return (torch.IntTensor(codes), torch.IntTensor([len(codes)]))
        # Anything that is not a single str is treated as an iterable of
        # strs. This avoids `collections.Iterable`, which no longer exists
        # on Python 3.10+.
        texts = list(text)
        length = [len(s) for s in texts]
        codes, _ = self.encode(''.join(texts))
        return (codes, torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
            raw (bool): if true, map every position to a character (blank
                becomes '-'); otherwise drop blanks and collapse repeats.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            # Work on plain Python ints rather than 0-dim tensors.
            n = int(length.reshape(-1)[0])
            assert t.numel() == n, "text with length: {} does not match declared length: {}".format(t.numel(), n)
            codes = t.reshape(-1).tolist()
            if raw:
                # code 0 (blank) indexes alphabet[-1], i.e. the appended '-'
                return ''.join([self.alphabet[c - 1] for c in codes])
            char_list = []
            for i, c in enumerate(codes):
                # skip blanks and immediate repetitions of the previous code
                if c != 0 and (not (i > 0 and codes[i - 1] == c)):
                    char_list.append(self.alphabet[c - 1])
            return ''.join(char_list)
        # batch mode
        lengths = length.reshape(-1).tolist()
        assert t.numel() == sum(lengths), "texts with length: {} does not match declared length: {}".format(t.numel(), sum(lengths))
        texts = []
        index = 0
        for l in lengths:
            texts.append(
                self.decode(
                    t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts
class averager(object):
    """Compute the running average of all tensor elements passed to ``add``."""

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the element count and the element sum of tensor ``v``."""
        # Detach first: summing a tensor that still carries autograd history
        # would chain the graphs of every added batch until the next reset().
        v = v.detach()
        self.n_count += v.numel()
        self.sum += v.sum()

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the average of the accumulated elements, or 0 if there are none."""
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
def oneHot(v, v_length, nc):
    """Expand concatenated label codes into a zero-padded one-hot tensor.

    Args:
        v: 1-D tensor with the codes of all sequences, concatenated.
        v_length: 1-D tensor with the length of each sequence.
        nc (int): size of the one-hot dimension.

    Returns:
        torch.FloatTensor of size [len(v_length), max(v_length), nc].
    """
    batchSize = v_length.size(0)
    maxLength = v_length.max()
    v_onehot = torch.FloatTensor(batchSize, maxLength, nc).zero_()
    offset = 0
    for row, span in enumerate(v_length):
        # codes of this sequence as a column of scatter indices
        codes = v[offset:offset + span].view(-1, 1).long()
        v_onehot[row, :span].scatter_(1, codes, 1.0)
        offset += span
    return v_onehot
def loadData(v, data):
    """Resize the buffer ``v`` in place to the size of ``data`` and copy ``data`` into it."""
    v.resize_(data.size())
    v.copy_(data)
def prettyPrint(v):
    """Print the size, type, maximum, minimum and mean of tensor ``v``."""
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    # max()/min()/mean() return 0-dim tensors; .item() extracts the Python
    # number (indexing a 0-dim tensor with [0] raises IndexError).
    print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(),
                                              v.mean().item()))
def assureRatio(img):
    """Ensure imgH <= imgW.

    ``img`` is a 4-D tensor unpacked as (b, c, h, w). When it is taller than
    wide it is bilinearly resized to (h, h); otherwise it is returned as is.
    """
    _, _, height, width = img.size()
    if height <= width:
        return img
    upsample = nn.UpsamplingBilinear2d(size=(height, height), scale_factor=None)
    return upsample(img)
3、测试
建一个test1的文件夹,里面放上你的训练集图片,嘿嘿,也可放没训练的,测试代码如下:
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import matplotlib.pyplot as plt
import collections
import os
import models.crnn as crnn
model_path = 'netCRNN_1400_1.pth'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
dir_img = "./test1/"

# Must match the value used for training: one class per alphabet character plus one.
nclass = len(alphabet) + 1
model = crnn.CRNN(32, 1, nclass, 256)
if torch.cuda.is_available():
    model = model.cuda()

# Load onto the CPU first so the checkpoint can be read on a machine without
# a GPU; load_state_dict copies the values into the model's own parameters.
load_model_ = torch.load(model_path, map_location='cpu')
state_dict_rename = collections.OrderedDict()
prefix = 'module.'
for k, v in load_model_.items():
    # Remove the DataParallel prefix only when present; blindly dropping the
    # first 7 characters would corrupt the keys of a checkpoint saved without it.
    name = k[len(prefix):] if k.startswith(prefix) else k
    state_dict_rename[name] = v
print('loading pretrained model from %s' % model_path)
model.load_state_dict(state_dict_rename)
# Inference mode is set once, not on every loop iteration.
model.eval()

converter = utils.strLabelConverter(alphabet)
transformer = dataset.resizeNormalize((100, 32))

list_img = os.listdir(dir_img)
for cnt, img_name in enumerate(list_img):
    print(cnt, img_name)
    path_img = os.path.join(dir_img, img_name)
    try:
        image = Image.open(path_img).convert('L')
    except (IOError, OSError) as err:
        # The directory may hold files that are not readable images.
        print('skip %s: %s' % (path_img, err))
        continue
    image = transformer(image)
    if torch.cuda.is_available():
        image = image.cuda()
    # add the batch dimension
    image = image.view(1, *image.size())

    # No gradients are needed for inference (and Variable wrappers are obsolete).
    with torch.no_grad():
        preds = model(image)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])
    raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    print('%-20s => %-20s' % (raw_pred, sim_pred))

    image_show = Image.open(path_img)
    plt.figure("show")
    plt.imshow(image_show)
    plt.show()
以上全部结束,接下来就是用tensorrtx进行加速了,大家可以自行参考tensorrtx(trtx)项目。
|