前面一篇文章介绍了quickdraw的前世今身:
宇宙最大的手绘草图数据集——QuickDraw 解析、下载、使用、训练、可视化(附完整代码)_沉迷单车的追风少年-CSDN博客
但是在生成的过程中,cpu利用率很低,大规模运用很难。举个例子,单类别10k个image生成居然花费了七天七夜,更别说想运用全部几百个类了。有必要从根本上解决效率问题啊!本文由浅入深,一步一步手把手带你彻底解决这一问题。同样的解决思路可以复用到其他深度学习任务当中。
目录
baseline任务
多线程——创建/销毁开销巨大,得不偿失
线程池——池化技术解决线程创建/销毁开销
为什么CPU利用率仍然这么低?
GIL全局解释器锁导致的Python多线程形如虚设
多进程并行计算大显身手
参考
baseline任务
上一篇文章中封装了一个class,直接复用即可。
import cv2
import os
from PIL import Image
import matplotlib
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
# from sketch_processing import draw_three
import numpy as np
import random
class DrawSketch(object):
def __init__(self):
pass
def scale_sketch(self, sketch, size=(448, 448)):
[_, _, h, w] = self.canvas_size_google(sketch)
if h >= w:
sketch_normalize = sketch / np.array([[h, h, 1]], dtype=np.float)
else:
sketch_normalize = sketch / np.array([[w, w, 1]], dtype=np.float)
sketch_rescale = sketch_normalize * np.array([[size[0], size[1], 1]], dtype=np.float)
return sketch_rescale.astype("int16")
def canvas_size_google(self, sketch):
"""
:param sketch: google sketch, quickDraw
:return: int list,[x, y, h, w]
"""
# get canvas size
vertical_sum = np.cumsum(sketch[1:], axis=0)
xmin, ymin, _ = np.min(vertical_sum, axis=0)
xmax, ymax, _ = np.max(vertical_sum, axis=0)
w = xmax - xmin
h = ymax - ymin
start_x = -xmin - sketch[0][0]
start_y = -ymin - sketch[0][1]
# sketch[0] = sketch[0] - sketch[0]
return [int(start_x), int(start_y), int(h), int(w)]
def draw_three(self, sketch, random_color=False, show=False, img_size=512):
"""
:param sketches: google quickDraw, (n, 3)
:param thickness: pass
:return: None
"""
# print("three ")
# print(sketch)
# print("-" * 70)
thickness = int(img_size * 0.025)
sketch = self.scale_sketch(sketch, (img_size, img_size)) # scale the sketch.
[start_x, start_y, h, w] = self.canvas_size_google(sketch=sketch)
start_x += thickness + 1
start_y += thickness + 1
canvas = np.ones((max(h, w) + 3 * (thickness + 1), max(h, w) + 3 * (thickness + 1), 3), dtype='uint8') * 255
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now = np.array([start_x, start_y])
first_zero = False
for stroke in sketch:
delta_x_y = stroke[0:0 + 2]
state = stroke[2:]
if first_zero:
pen_now += delta_x_y
first_zero = False
continue
cv2.line(canvas, tuple(pen_now), tuple(pen_now + delta_x_y), color, thickness=thickness)
if int(state) == 1: # next stroke
first_zero = True
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now += delta_x_y
if show:
key = cv2.waitKeyEx()
if key == 27: # esc
cv2.destroyAllWindows()
exit(0)
return cv2.resize(canvas, (img_size, img_size))
class SketchData(object):
def __init__(self, dataPath, model="train"):
self.dataPath = dataPath
self.model = model
# 加载数据
def load(self):
dataset_origin_list = []
category_list = self.getCategory()
for each_name in category_list:
# npz_test = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)["test"]
npz_tmp = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)[self.model]
print(f"dataset: {each_name} added.")
dataset_origin_list.append(npz_tmp)
return dataset_origin_list
# 获取类别列表
def getCategory(self):
category_list = os.listdir(self.dataPath)
return category_list
if __name__ == '__main__':
sketchdata = SketchData(dataPath='./dataset_npz')
category_list = sketchdata.getCategory()
dataset_origin_list = sketchdata.load()
# 作图
for category_index in range(len(category_list)):
sample_category_name = category_list[category_index]
print(sample_category_name)
save_name = sample_category_name.replace(".npz", "")
# 创建文件夹
folder = os.path.exists(f"./save_img/{save_name}/")
if not folder:
os.makedirs(f"./save_img/{save_name}/")
print(f"./save_img/{save_name}/ is new mkdir!")
drawsketch = DrawSketch()
# 作图
for image_index in range(10):
# sample_sketch = dataset_origin_list[sample_category_name.index(sample_category_name)][index]
sample_sketch = dataset_origin_list[category_list.index(sample_category_name)][image_index]
sketch_cv = drawsketch.draw_three(sample_sketch, True)
plt.xticks([]) # 去掉x轴
plt.yticks([]) # 去掉y轴
plt.axis('off') # 去掉坐标轴
plt.imshow(sketch_cv)
plt.savefig(f"./save_img/{save_name}/{image_index}.jpg")
print(f"{save_name}/{image_index}.jpg is saved!")
多线程——创建/销毁开销巨大,得不偿失
具体代码这里就不贴了,this is a bad idea。
在需要几个线程的时候,或许这是一个好方法,但是我们希望能创建几十个/几百个疯狂run,这样肯定不行。
线程池——池化技术解决线程创建/销毁开销
为了线程创建/销毁时候的开销问题,我们引入的线程池技术。相比于C++里面的池化技术,Python解释型语言确实开发效率高,很快就写完了。
但是注意我们需要一个全局锁,锁住队列,防止取的时候出现死锁的问题。大体的框架如下:
#!/usr/bin/python3
# 多线程生成草图image
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading
# 全局队列加锁
indexQueue = Queue(maxsize=10000)
queueLock = threading.Lock()
for i in range(0, 10000):
indexQueue.put(i)
def worker():
# 每一次取队列中没有画的sketch下标加锁
queueLock.acquire()
if not indexQueue.empty():
index = indexQueue.get()
else:
print("queue is empty")
queueLock.release()
print(f"thread write {index} image!")
# print(f'thread is over')
if __name__ == '__main__':
# 开一个线程池
with ThreadPoolExecutor(max_workers=1000) as t:
while not indexQueue.empty():
t.submit(worker)
exit()
完整的代码如下:
#!/usr/bin/python3
# 多线程生成草图image
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import asyncio
class DrawSketch(object):
def __init__(self):
pass
def scale_sketch(self, sketch, size=(448, 448)):
[_, _, h, w] = self.canvas_size_google(sketch)
if h >= w:
sketch_normalize = sketch / np.array([[h, h, 1]], dtype=np.float)
else:
sketch_normalize = sketch / np.array([[w, w, 1]], dtype=np.float)
sketch_rescale = sketch_normalize * np.array([[size[0], size[1], 1]], dtype=np.float)
return sketch_rescale.astype("int16")
def canvas_size_google(self, sketch):
"""
:param sketch: google sketch, quickDraw
:return: int list,[x, y, h, w]
"""
# get canvas size
vertical_sum = np.cumsum(sketch[1:], axis=0)
xmin, ymin, _ = np.min(vertical_sum, axis=0)
xmax, ymax, _ = np.max(vertical_sum, axis=0)
w = xmax - xmin
h = ymax - ymin
start_x = -xmin - sketch[0][0]
start_y = -ymin - sketch[0][1]
# sketch[0] = sketch[0] - sketch[0]
return [int(start_x), int(start_y), int(h), int(w)]
def draw_three(self, sketch, random_color=False, show=False, img_size=512):
"""
:param sketches: google quickDraw, (n, 3)
:param thickness: pass
:return: None
"""
thickness = int(img_size * 0.025)
sketch = self.scale_sketch(sketch, (img_size, img_size)) # scale the sketch.
[start_x, start_y, h, w] = self.canvas_size_google(sketch=sketch)
start_x += thickness + 1
start_y += thickness + 1
canvas = np.ones((max(h, w) + 3 * (thickness + 1), max(h, w) + 3 * (thickness + 1), 3), dtype='uint8') * 255
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now = np.array([start_x, start_y])
first_zero = False
for stroke in sketch:
delta_x_y = stroke[0:0 + 2]
state = stroke[2:]
if first_zero:
pen_now += delta_x_y
first_zero = False
continue
cv2.line(canvas, tuple(pen_now), tuple(pen_now + delta_x_y), color, thickness=thickness)
if int(state) == 1: # next stroke
first_zero = True
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now += delta_x_y
if show:
key = cv2.waitKeyEx()
if key == 27: # esc
cv2.destroyAllWindows()
exit(0)
return cv2.resize(canvas, (img_size, img_size))
class SketchData(object):
def __init__(self, dataPath, model="train"):
self.dataPath = dataPath
self.model = model
# 加载数据
def load(self):
dataset_origin_list = []
category_list = self.getCategory()
for each_name in category_list:
# npz_test = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)["test"]
npz_tmp = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)[self.model]
print(f"dataset: {each_name} added.")
dataset_origin_list.append(npz_tmp)
return dataset_origin_list
# 获取类别列表
def getCategory(self):
category_list = os.listdir(self.dataPath)
return category_list
# 全局队列加锁
MAXQUEUESIZE = 10000
MAXTHREADSIZE = 100
indexQueue = Queue(maxsize=MAXQUEUESIZE)
queueLock = threading.Lock()
for i in range(0, MAXQUEUESIZE):
indexQueue.put(i)
def worker():
# 每一次取队列中没有画的sketch下标加锁
if not queueLock.acquire(blocking=False):
print(f"queueLock acquire is timeout!")
return
if not indexQueue.empty():
try:
image_index = indexQueue.get_nowait()
except:
# timeout return and release
print(f"queue get is timeout!")
queueLock.release()
return
else:
print("queue is empty")
queueLock.release()
sample_sketch = dataset_origin_list[category_list.index(sample_category_name)][image_index]
sketch_cv = drawsketch.draw_three(sample_sketch)
plt.xticks([]) # 去掉x轴
plt.yticks([]) # 去掉y轴
plt.axis('off') # 去掉坐标轴
plt.imshow(sketch_cv)
plt.savefig(f"./sketch_image/{save_name}/{save_name}_{image_index}.png")
print(f"{save_name}/{save_name}_{image_index}.png is saved!")
if __name__ == '__main__':
sketchdata = SketchData(dataPath='./sketch_dataset_airplane')
category_list = sketchdata.getCategory()
dataset_origin_list = sketchdata.load()
# 作图
for category_index in range(len(category_list)):
sample_category_name = category_list[category_index]
print(sample_category_name)
save_name = sample_category_name.replace(".npz", "")
# 创建文件夹
folder = os.path.exists(f"./sketch_image/{save_name}/")
if not folder:
os.makedirs(f"./sketch_image/{save_name}/")
print(f"./sketch_image/{save_name}/ is new mkdir!")
drawsketch = DrawSketch()
with ThreadPoolExecutor(max_workers=MAXTHREADSIZE) as t:
while not indexQueue.empty():
# t.shutdown(wait=False)
t.submit(worker)
exit()
为什么CPU利用率仍然这么低?
我以为能直接跑到90%,这样多舒服。但是仍然只有1%左右?
GIL全局解释器锁导致的Python多线程形如虚设
?这里就不得指出:python中的多线程其实并不是真正的多线程。
Python代码的执行由Python虚拟机(解释器)来控制。Python在设计之初就考虑要在主循环中,同时只有一个线程在执行,就像单CPU的系统中运行多个进程那样,内存中可以存放多个程序,但任意时刻,只有一个程序在CPU中运行。同样地,虽然Python解释器可以运行多个线程,只有一个线程在解释器中运行。
对Python虚拟机的访问由全局解释器锁(GIL)来控制,正是这个锁能保证同时只有一个线程在运行。在多线程环境中,Python虚拟机按照以下方式执行:
- 1.设置GIL。
- 2.切换到一个线程去执行。
- 3.运行。
- 4.把线程设置为睡眠状态。
- 5.解锁GIL。
- 6.再次重复以上步骤。
多进程并行计算大显身手
为了减少进程创建的开销,我们继续使用进程池技术来解决这一问题。
Python提供了非常好用的多进程包multiprocessing,只需要定义一个函数,Python会完成其他所有事情。借助这个包,可以轻松完成从单进程到并发执行的转换。multiprocessing支持子进程、通信和共享数据、执行不同形式的同步,提供了Process、Queue、Pipe、Lock等组件。
详细学习可以看下面这一篇博客:
Python多进程编程 - jihite - 博客园
我一下子创建128个进程,疯狂跑,舒服了。哈哈哈。但是注意需要创建阻塞进程,无锁操作。因为画图操作需要时间,如果是非阻塞,那么所有的进程都会直接返回。这样的无锁设计比之前的有锁操作完美了很多,具体代码如下:
#!/usr/bin/python
# 多线程生成草图image
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import threading
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import asyncio
import multiprocessing
class DrawSketch(object):
def __init__(self):
pass
def scale_sketch(self, sketch, size=(448, 448)):
[_, _, h, w] = self.canvas_size_google(sketch)
if h >= w:
sketch_normalize = sketch / np.array([[h, h, 1]], dtype=np.float)
else:
sketch_normalize = sketch / np.array([[w, w, 1]], dtype=np.float)
sketch_rescale = sketch_normalize * np.array([[size[0], size[1], 1]], dtype=np.float)
return sketch_rescale.astype("int16")
def canvas_size_google(self, sketch):
"""
:param sketch: google sketch, quickDraw
:return: int list,[x, y, h, w]
"""
# get canvas size
vertical_sum = np.cumsum(sketch[1:], axis=0)
xmin, ymin, _ = np.min(vertical_sum, axis=0)
xmax, ymax, _ = np.max(vertical_sum, axis=0)
w = xmax - xmin
h = ymax - ymin
start_x = -xmin - sketch[0][0]
start_y = -ymin - sketch[0][1]
# sketch[0] = sketch[0] - sketch[0]
return [int(start_x), int(start_y), int(h), int(w)]
def draw_three(self, sketch, random_color=False, show=False, img_size=512):
"""
:param sketches: google quickDraw, (n, 3)
:param thickness: pass
:return: None
"""
thickness = int(img_size * 0.025)
sketch = self.scale_sketch(sketch, (img_size, img_size)) # scale the sketch.
[start_x, start_y, h, w] = self.canvas_size_google(sketch=sketch)
start_x += thickness + 1
start_y += thickness + 1
canvas = np.ones((max(h, w) + 3 * (thickness + 1), max(h, w) + 3 * (thickness + 1), 3), dtype='uint8') * 255
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now = np.array([start_x, start_y])
first_zero = False
for stroke in sketch:
delta_x_y = stroke[0:0 + 2]
state = stroke[2:]
if first_zero:
pen_now += delta_x_y
first_zero = False
continue
cv2.line(canvas, tuple(pen_now), tuple(pen_now + delta_x_y), color, thickness=thickness)
if int(state) == 1: # next stroke
first_zero = True
if random_color:
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
else:
color = (0, 0, 0)
pen_now += delta_x_y
if show:
key = cv2.waitKeyEx()
if key == 27: # esc
cv2.destroyAllWindows()
exit(0)
return cv2.resize(canvas, (img_size, img_size))
class SketchData(object):
def __init__(self, dataPath, model="train"):
self.dataPath = dataPath
self.model = model
# 加载数据
def load(self):
dataset_origin_list = []
category_list = self.getCategory()
for each_name in category_list:
# npz_test = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)["test"]
npz_tmp = np.load(f"./{self.dataPath}/{each_name}", encoding="latin1", allow_pickle=True)[self.model]
print(f"dataset: {each_name} added.")
dataset_origin_list.append(npz_tmp)
return dataset_origin_list
# 获取类别列表
def getCategory(self):
category_list = os.listdir(self.dataPath)
return category_list
MAXQUEUESIZE = 10000
MAXTHREADSIZE = 128
drawsketch = DrawSketch()
def func(image_index, sample_category_name, save_name):
sample_sketch = dataset_origin_list[category_list.index(sample_category_name)][image_index]
sketch_cv = drawsketch.draw_three(sample_sketch)
plt.xticks([]) # 去掉x轴
plt.yticks([]) # 去掉y轴
plt.axis('off') # 去掉坐标轴
plt.imshow(sketch_cv)
plt.savefig(f"./sketch_image/{save_name}/{save_name}_{image_index}.png")
print(f"{save_name}/{save_name}_{image_index}.png is saved!")
if __name__ == '__main__':
sketchdata = SketchData(dataPath='./sketch_dataset_17')
category_list = sketchdata.getCategory()
dataset_origin_list = sketchdata.load()
pool = multiprocessing.Pool(processes=MAXTHREADSIZE)
# 作图
for category_index in range(len(category_list)):
sample_category_name = category_list[category_index]
print(sample_category_name)
save_name = sample_category_name.replace(".npz", "")
# 创建文件夹
folder = os.path.exists(f"./sketch_image/{save_name}/")
if not folder:
os.makedirs(f"./sketch_image/{save_name}/")
print(f"./sketch_image/{save_name}/ is new mkdir!")
for i in range(0, MAXQUEUESIZE):
# 维持执行的进程总数为processes,当一个进程执行完毕后会添加新的进程进去
# pool.apply_async(func, (i,)) # 非阻塞
pool.apply(func, (i, sample_category_name, save_name)) # 阻塞
pool.close()
pool.join() # 调用join之前,先调用close函数,否则会出错。执行完close后不会有新的进程加入到pool,join函数等待所有子进程结束
print(f"all process is end! save path is ./sketch_image/{save_name}/, category_list is {category_list}")
exit()
'''
# 全局队列加锁
MAXQUEUESIZE = 1000
MAXTHREADSIZE = 10
indexQueue = Queue(maxsize=MAXQUEUESIZE)
queueLock = threading.Lock()
for i in range(0, MAXQUEUESIZE):
indexQueue.put(i)
def worker():
# 每一次取队列中没有画的sketch下标加锁
if not queueLock.acquire(blocking=False):
print(f"queueLock acquire is timeout!")
return
if not indexQueue.empty():
try:
image_index = indexQueue.get_nowait()
except:
# timeout return and release
print(f"queue get is timeout!")
queueLock.release()
return
else:
print("queue is empty")
queueLock.release()
sample_sketch = dataset_origin_list[category_list.index(sample_category_name)][image_index]
sketch_cv = drawsketch.draw_three(sample_sketch)
plt.xticks([]) # 去掉x轴
plt.yticks([]) # 去掉y轴
plt.axis('off') # 去掉坐标轴
plt.imshow(sketch_cv)
plt.savefig(f"./sketch_image/{save_name}/{save_name}_{image_index}.png")
print(f"{save_name}/{save_name}_{image_index}.png is saved!")
if __name__ == '__main__':
sketchdata = SketchData(dataPath='./sketch_dataset_airplane')
category_list = sketchdata.getCategory()
dataset_origin_list = sketchdata.load()
# 作图
for category_index in range(len(category_list)):
sample_category_name = category_list[category_index]
print(sample_category_name)
save_name = sample_category_name.replace(".npz", "")
# 创建文件夹
folder = os.path.exists(f"./sketch_image/{save_name}/")
if not folder:
os.makedirs(f"./sketch_image/{save_name}/")
print(f"./sketch_image/{save_name}/ is new mkdir!")
drawsketch = DrawSketch()
with ThreadPoolExecutor(max_workers=MAXTHREADSIZE) as t:
while not indexQueue.empty():
# t.shutdown(wait=False)
t.submit(worker)
exit()
'''
好了,今天的分享就到这里结束了。下面就是以跑数据为理由的摸鱼时间啦哈哈哈~
参考
|