方法二:利用scikit-image库生成图像标签数据集
提示:此处独立使用图像库scikit-image。即仅用io读图和显示处理服装关键点数据集
安装OpenCV的时候,安装opencv_python: pip install scikit-image 导入的时候:from skimage import io, transform, draw
服装关键点数据集下载:链接:https://pan.baidu.com/s/1A_UEaulqsz60OhC5BStA9g?pwd=hr47 提取码:hr47
数据集描述:pytorch生成图像标签数据集的三种方式–前言
Skimage模块常用子模块
Skimage模块常用子模块: io 用于图像读取、保存,显示图片和视频。color 颜色空间变换。filters 包括图像增强、边缘检测、排序滤波、自动阈值。 draw 基于numpy数组图像绘制,线段、矩形、圆和文本。transform 几何变换包括:旋转,拉伸,收缩等非回调函数。 Exposure 曝光调整包括:强度、亮度、直方图均衡化。Feature特征检测与提取。 measure 图像属性测量:相似性、等高线。segmentation 图像分割。 restoration 图像恢复。
生成 图像-关键点坐标标签 数据集
此例,服装类型和关键点图像-标签数据集,引入 io, transform, draw 的函数模块进行处理。 数据集展示:(图像,坐标,类型)和只管图像显示。
代码:dataset_by_skimage.py
import os
import numpy as np
import pandas as pd
import torch
from skimage import io, transform, draw
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class KeyPointsDataSet(Dataset):
"""服装关键点标记数据集"""
def __init__(self, root_dir, image_set='train', transforms=None):
"""
初始化数据集
:param root_dir: 数据目录(.csv和images的根目录)
:param image_set: train训练,val验证,test测试
:param transforms(callable,optional):图像变换-可选
标签数据文件格式为csv_file: 标签csv文件(内容:图像相对地址-category类型-标签coordination坐标)
"""
self._imgset = image_set
self._image_paths = []
self._labels = []
self._cates = []
self._csv_file = os.path.join(root_dir, image_set + '.csv')
self._categories = ['blouse', 'outwear', 'dress', 'trousers', 'skirt', ]
self.root_dir = root_dir
self._transform = transforms
self.__getFileList()
def __len__(self):
return len(self._image_paths)
def __getitem__(self, idx):
img_id = self._image_paths[idx]
img_id = os.path.join(self.root_dir, img_id)
image = io.imread(img_id)
imgSize = image.shape[0:2]
label = np.asfortranarray(self._labels[idx])
category = self._categories.index(self._cates[idx])
if self._transform:
image = self._transform(image)
else:
image = transform.resize(image, output_shape=(256, 256))
afterSize = image.shape[0:2]
bi = np.array((afterSize[1], afterSize[0])) / np.array((imgSize[1], imgSize[0]))
label[:, 0:2] = label[:, 0:2] * bi
return image, label, category
def __getFileList(self):
file_info = pd.read_csv(self._csv_file)
self._image_paths = file_info.iloc[:, 0]
self._cates = file_info.iloc[:, 1]
if self._imgset == 'train':
landmarks = file_info.iloc[:, 2:26].values
for i in range(len(landmarks)):
label = []
for j in range(24):
plot = landmarks[i][j].split('_')
coor = []
for per in plot:
coor.append(int(per))
label.append(coor)
self._labels.append(np.concatenate(label))
self._labels = np.array(self._labels).reshape((-1, 24, 3))
else:
self._labels = np.ones((len(self._image_paths), 24, 3)) * (-1)
def showImageAndCoor(img, coords):
for coor in coords:
if coor[2] == -1:
pass
else:
rr, cc = draw.circle(coor[1], coor[0], 4)
draw.set_color(img, [rr, cc], [255, 0, 0])
io.imshow(img)
io.show()
if __name__ == "__main__":
fashionDataset = KeyPointsDataSet(root_dir=r"E:/Datasets/Fashion/Fashion AI-keypoints_24/train/",
image_set="train",
)
dataloader = DataLoader(dataset=fashionDataset, batch_size=4)
for i_batch, data in enumerate(dataloader):
img, label, category = data
img, label, category = img.numpy(), label.numpy(), category.numpy()
print(img.shape, label.shape, category)
showImageAndCoor(img[0], label[0])
注意事项
- io读图的数据结构也是为(h, w, c)=(高,宽,通道),坐标组是(宽x, 高y)。统一伸缩时注意对应。
- 本文输出数据集为了显示并没有对图像数组进行归一化或标准化操作,用的时候需要加上归一化。
|