本博客参考了博主“K同学啊”的博客。CNN 部分与前面的文章没有太大区别;对我来说,本篇最大的收获是数据处理部分的操作。
1.导入库
import tensorflow as tf
import matplotlib.pyplot as plt
import os,PIL,random,pathlib
import numpy as np
from tensorflow.keras import datasets, layers, models
2.数据导入
def path_to_label(path):
    """Return the captcha text encoded in an image file name.

    The label is the part of the file name before its first ".", e.g.
    ``.../captcha/2b827.png`` -> ``"2b827"``. Going through ``pathlib``
    makes this independent of the OS path separator and of how deep the
    data directory sits, unlike splitting on a hard-coded separator and
    picking a fixed component index.
    """
    return pathlib.Path(path).name.split(".")[0]


data_dir = pathlib.Path("E:/tmp/.keras/datasets/yzm_photos/captcha")
# Every entry of the captcha folder, as plain string paths.
all_images_paths = [str(path) for path in data_dir.glob('*')]
# Shuffle once so that the file order (and any later positional split) is random.
random.shuffle(all_images_paths)
# The label of each image is its file name without the extension.
all_label_names = [path_to_label(path) for path in all_images_paths]
查看数据:
3.标签数字化(one-hot 编码)
# Symbols that may appear in a captcha: the ten digits followed by the
# 26 lowercase letters. The position in char_set is the one-hot column index.
number = list("0123456789")
alphabet = list("abcdefghijklmnopqrstuvwxyz")
char_set = number + alphabet
char_set_len = len(char_set)
# The label length is taken from the first label; every label is assumed
# to have this same length.
label_name_len = len(all_label_names[0])
def text2vec(text, *, length=None, charset=None):
    """Encode a captcha string as a one-hot matrix.

    Row ``i`` holds 1.0 in the column given by the index of ``text[i]`` in
    ``charset``; every other entry is 0.0. Rows past ``len(text)`` stay
    all-zero.

    Args:
        text: The label string to encode.
        length: Number of rows. Defaults to the module-level
            ``label_name_len``.
        charset: Ordered list of allowed symbols. Defaults to the
            module-level ``char_set``.

    Returns:
        A float64 numpy array of shape ``(length, len(charset))``.

    Raises:
        ValueError: If a character of ``text`` is not in ``charset``.
        IndexError: If ``text`` is longer than ``length``.
    """
    if length is None:
        length = label_name_len
    if charset is None:
        charset = char_set
    vector = np.zeros([length, len(charset)])
    for i, c in enumerate(text):
        vector[i][charset.index(c)] = 1.0
    return vector


def vec2text(vec, *, charset=None):
    """Decode a sequence of symbol indices back into a string.

    This is the inverse of ``text2vec`` once the one-hot matrix (or the
    model's per-position probabilities) has been reduced with
    ``argmax`` over its last axis.

    Args:
        vec: Iterable of integer indices into ``charset``.
        charset: Ordered list of symbols. Defaults to the module-level
            ``char_set``.

    Returns:
        The decoded string.
    """
    if charset is None:
        charset = char_set
    # int() accepts numpy integer scalars such as those produced by argmax.
    return "".join(charset[int(i)] for i in vec)
# One-hot matrix for every label, in the same order as all_label_names.
all_labels = list(map(text2vec, all_label_names))
4.构建一个tf.data.Dataset
def preprocess_image(image):
    """Turn encoded image bytes into a normalized single-channel tensor.

    The bytes are decoded with one channel, resized to 50 x 200
    (height x width) and divided by 255.0.
    """
    decoded = tf.image.decode_jpeg(image, channels=1)
    resized = tf.image.resize(decoded, [50, 200])
    return resized / 255.0
def load_and_preprocess_image(path):
    """Read the file at ``path`` and hand its raw bytes to preprocess_image."""
    return preprocess_image(tf.io.read_file(path))
# Labels: one one-hot matrix per image, in the same (already shuffled)
# order as all_images_paths.
label_ds = tf.data.Dataset.from_tensor_slices(all_labels)

# Images: stream the file paths through the loader/preprocessor.
path_ds = tf.data.Dataset.from_tensor_slices(all_images_paths)
image_ds = path_ds.map(load_and_preprocess_image)

# (image, label) pairs; the first 1000 are used for training, the rest for testing.
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
train_ds, test_ds = image_label_ds.take(1000), image_label_ds.skip(1000)
5.搭建CNN网络
# Training hyper-parameters.
batch_size = 16
epochs = 20

# Group consecutive (image, label) pairs into mini-batches.
train_ds, test_ds = train_ds.batch(batch_size), test_ds.batch(batch_size)
# Two conv/pool stages extract features from the (50, 200, 1) input. The
# dense head emits one score per (character position, symbol) pair, which
# is reshaped to (label_name_len, char_set_len) and softmaxed.
model = models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(50, 200, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    # Without a non-linearity this layer and the next would compose into a
    # single linear map, so the 1000-unit hidden layer would add no
    # expressive power.
    tf.keras.layers.Dense(1000, activation='relu'),
    tf.keras.layers.Dense(label_name_len * char_set_len),
    tf.keras.layers.Reshape([label_name_len, char_set_len]),
    # Softmax over the last axis: a distribution over char_set per position.
    tf.keras.layers.Softmax(),
])
# Categorical cross-entropy matches the one-hot (position, symbol) targets.
model.compile(
    loss='categorical_crossentropy',
    optimizer="adam",
    metrics=['accuracy'],
)

# Train on the batched training split and evaluate on the test split each epoch.
history = model.fit(train_ds, epochs=epochs, validation_data=test_ds)
运行结果如图所示。若要进一步优化实验结果,可以调整部分超参数(如批大小、训练轮数),这个调参过程通常比较漫长。
6.模型保存、加载与预测
模型保存
# Persist the trained model to an HDF5 (.h5) file.
model.save(filepath='E:/tmp/.keras/datasets/yzm_photos/yzm_model.h5')
模型加载
# Restore the model saved above from the same .h5 file.
new_model = tf.keras.models.load_model(filepath='E:/tmp/.keras/datasets/yzm_photos/yzm_model.h5')
预测
# Show the model's predictions for up to 15 images of the first test batch
# on a 5 x 3 grid.
plt.figure(figsize=(10, 8))
for images, labels in test_ds.take(1):
    # One batched forward pass instead of a predict() call per image.
    # pre has shape (batch, label_name_len, char_set_len) per the model's
    # Reshape layer; argmax over the last axis picks a symbol per position.
    pre = new_model.predict(images)
    pred_indices = np.argmax(pre, axis=2)
    n_show = min(15, images.shape[0])
    for i in range(n_show):
        plt.subplot(5, 3, i + 1)
        # images[i] is (50, 200, 1); drop the channel axis for imshow.
        plt.imshow(tf.squeeze(images[i], axis=-1), cmap="gray")
        # Decode indices back to text via char_set (no separate helper needed).
        plt.title("".join(char_set[k] for k in pred_indices[i]))
        plt.axis("off")
plt.show()
识别结果总体比较准确,但也有部分验证码识别错误,需要进一步优化。努力加油啊!
|