由于实际项目中没有那么多类,所以就假设有10个种类,每个种类设置对应一种颜色
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
if __name__ == '__main__':
#
input_shape = [5,5]
img = np.random.randint(0,10,size=input_shape,dtype='uint8')
plt.subplot(1,2,1)
plt.title('pred')
plt.imshow(img,cmap='gray')
# print(img)
imgshow = np.zeros(input_shape+[3],dtype='uint8')
# 标签为5的位置rgb设置为(0,255,255)
label_color = {5:[0,255,255],4:[255,0,0],3:[0,255,0],2:[0,0,255],
1:[255,255,0],6:[255,0,255]}
for key in label_color.keys():
# r通道标签为5的位置,设置为0
imgshow[...,0][img==key] = label_color[key][0]
# g通道标签为5的位置,设置为255
imgshow[...,1][img==key] = label_color[key][1]
# b通道标签为5的位置,设置为255
imgshow[...,2][img==key] = label_color[key][2]
# image = Image.fromarray(imgshow)
# image.show()
plt.subplot(1,2,2)
plt.title('color')
plt.imshow(imgshow)
plt.show()
?
?可以将鼠标放上面进行验证
|